Is it possible to use objects with Google's Jax machine learning library

1.4k Views Asked by At

I am trying to write a DC Gan network using Google's Jax machine learning library. To do this, I created objects to serve as the discriminator and generator, however, as I was testing the discriminator, I got the error:

    TypeError: Argument '<__main__.Discriminator object at 0x7fdfa5c6ffd0>' of type <class '__main__.Discriminator'> is not a valid JAX type

I looked through the examples on the Jax github page, and, from what I saw, none of the examples there use objects, which leads me to hypothesize that it is probably just not possible to use objects with Jax. But if this is the case, I don't really understand why the use of objects wouldn't be possible, and would this be something that will be implemented in the future? Am I just naively overlooking something?

Here is my Discriminator object:

class Discriminator():
    def __init__(self):
        self.step_size = 0.0001
        self.image_shape = (256,256,3)
        self.params = []
        num_layers = 6
        num_filters = 64
        filter_size = 4
        self.params.append(create_conv_layer(3, 
                                             num_filters, 
                                             filter_size, 
                                             filter_size, 
                                             random.PRNGKey(0)))
        for l in range(1, num_layers):
            self.params.append(create_conv_layer(64*2**(l-1), 
                                                 64*2**l, 
                                                 filter_size,   
                                                 filter_size, 
                                                 random.PRNGKey(0)))
        self.params.append(create_conv_layer(64*2**num_filters, 
                                             1, 
                                             filter_size, 
                                             filter_size, 
                                             random.PRNGKey(0)))

    def predict(self):
        activations = image
        for w, b in params[:-1]:
            outputs = conv_forward(activations,w,b,stride=2)
            outputs = batch_normalization(outputs)
            activations = leaky_relu(outputs)
        final_w, final_b = params[-1]
        return sigmoid(conv_forward(activations,final_w,final_b,))

    def batched_predict(self, images):
        shape = [None] + list(self.image_shape)
        return vmap(self.predict, in_axes=shape)(self.params, images)

    def loss(self, params, images, targets):
        preds = self.batched_predict(params, images)
        return -np.sum(preds * targets)

    def accuracy(self, images, targets):
        predicted_class = np.round(np.ravel(batched_predict(images)))
        return np.mean(predicted_class == target_class)

    @jit
    def update(self, params, x, y):
        grads = grad(self.loss)(params, x, y)
        return [(w - self.step_size * dw, b - self.step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]

And I update the parameters here:

num_epochs = 5
batch_size = 64
steps_per_epoch = train_images.shape[0] // batch_size
discrim = Discriminator()
params = discrim.params

print("lets-a-go!")
for epoch in range(num_epochs):
    start_time = time.time()
    for step in range(steps_per_epoch):
        x, y = simple_data_generator(batch_size)
        params = discrim.update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = discrim.accuracy(train_images, train_labels)
    test_acc = discrim.accuracy(test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
0

There are 0 best solutions below