TensorFlow: Tensor cannot be accessed inside test_step function


I am writing a VAE that uses a PID algorithm to tune its KL divergence (see Shao et al., 2020). In short, before the total loss is calculated, the KL-divergence term is multiplied by a weight beta(t). This weight changes over the course of training according to the following equation:

β(t) = K_p / (1 + exp(e(t))) - K_i * Σ_{j=0..t} e(j) + β_min

e(t) is the difference between the desired KL divergence and the current KL divergence. I don't think the first term on the right-hand side is a problem, but the second (integral) term is, because it needs the full error history across training steps.
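
To make the update concrete, here is a minimal standalone sketch (pid_beta, k_p, k_i and beta_min are illustrative names only; in my model they correspond to self.proportional_kl, self.integral_kl and the constant self.derivative_kl used below):

import tensorflow as tf

def pid_beta(errors: tf.Tensor, k_p: float, k_i: float, beta_min: float) -> tf.Tensor:
    """Compute beta(t) from the 1-D error history e(0)..e(t), most recent last."""
    # Proportional term: K_p / (1 + exp(e(t)))
    proportional = k_p / (1.0 + tf.exp(errors[-1]))
    # Integral term: K_i * sum over the whole error history
    integral = k_i * tf.reduce_sum(errors)
    # Constant offset beta_min
    return proportional - integral + beta_min

My implementation of this, used inside the test_step() function, causes the following error: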

The tensor <tf.Tensor 'add_2:0' shape=() dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=train_function, id=139749247929504), which is out of scope.
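
From the message, my understanding is that the tensor is created inside the traced train_function graph and then captured again when Keras traces the separate test_function graph. Here is a stripped-down standalone example (not my model; Holder, writer and reader are hypothetical names) that, as far as I can tell, fails in the same way:

import tensorflow as tf

class Holder:
    def __init__(self):
        # TensorArray created eagerly, just like in my model's __init__
        self.ta = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False
        )

holder = Holder()

@tf.function
def writer():
    # The written value is a tensor that belongs to writer's FuncGraph
    holder.ta = holder.ta.write(0, tf.constant(1.0) + tf.constant(2.0))

@tf.function
def reader():
    # Reading it from a different FuncGraph triggers the out-of-scope error
    return holder.ta.read(0)

writer()
reader()  # raises the same "cannot be accessed from here ... out of scope" error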

Here is my current implementation. In the __init__ of my model, I initialize an empty TensorArray to keep track of every e(t) from 0 to t, so that I can sum them later. Roughly (only the PID-related attributes are shown):
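
def __init__(self):
        ...
        # ^^^ all other init stuff
        # History of e(0)..e(t); clear_after_read=False so the whole array
        # can be stacked and summed on every step
        self.beta_errors = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False
        )
        # Training iteration counter (assumed to start at 0)
        self.beta_iteration_counter = 0

Here is what the train_step() function looks like: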

def train_step(self, data: npt.ArrayLike) -> dict:
        
        # Set gradient context manager
        with tf.GradientTape() as tape:
            # Get latent values
            mean, log_variance, sample = self.encoder(data)
            # Reconstruct from the sample
            reconstruction = self.decoder(sample)
            # Calculate reconstruction loss
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.categorical_crossentropy(data, reconstruction), axis=0
                )
            )
            # Calculate KL Loss
            kl_loss = self.kullback_leibler_loss(mean=mean, log_variance=log_variance)
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            # Get error vs desired KL
            error = self.desired_kl - kl_loss
            # Write the new error to the TensorArray at the current iteration index
            self.beta_errors = self.beta_errors.write(
                self.beta_iteration_counter, error
            )
            # Calculate proportional term
            proportional_term = self.proportional_kl / (1 + tf.exp(error))
            # Calculate integral term
            integral_term = self.integral_kl * tf.reduce_sum(
                self.beta_errors.stack()
            )
            # Get control score
            control_score = proportional_term - integral_term + self.derivative_kl

            # Calculate total loss
            total_loss = reconstruction_loss + control_score * kl_loss

        # Apply gradient
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update losses
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.kl_beta_tracker.update_state(control_score)

        # Return dictionary of losses
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "beta_score": self.kl_beta_tracker.result(),
        }

For reference, here is the call() function:

    def call(self, inputs):
        samples = self.encoder(inputs)
        self.beta_iteration_counter += 1
        return self.decoder(samples[2])

The way I wanted to make beta(t) available inside the test_step() function is as follows. In __init__:

def __init__(self):
        ...
        # ^^^ all other init stuff
        self.betas = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)

Then near the end of train_step():

def train_step(self, data: npt.ArrayLike) -> dict:
        ...
        # ^^^ all the other train_step stuff
        self.betas = self.betas.write(self.beta_iteration_counter, control_score)
        ...

Lastly:

def test_step(self, data: npt.ArrayLike):

        validation_data, _ = data
        mean, log_variance, sample = self.encoder(validation_data)
        reconstruction = self.decoder(sample)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.categorical_crossentropy(validation_data, reconstruction),
                axis=0,
            )
        )
        # Get control score
        control_score = self.betas.read(self.beta_iteration_counter)

        # Calculate KL Loss
        kl_loss = self.kullback_leibler_loss(mean=mean, log_variance=log_variance)
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

        # Calculate total loss
        total_loss = reconstruction_loss + control_score * kl_loss
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "total_loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

This causes the error seen above. Without this, the training works just fine. I have also tried putting the test_step() function under a @tf.function decorator, i.e. nothing more than:
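
@tf.function
def test_step(self, data: npt.ArrayLike):
        ...
        # ^^^ same body as shown above

So at this point I am not sure how to get the correct validation loss to be calculated during training.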
