I am trying to write a DCGAN using Google's JAX machine learning library. To do this, I created classes to serve as the discriminator and generator, but while testing the discriminator I got the error:
TypeError: Argument '<__main__.Discriminator object at 0x7fdfa5c6ffd0>' of type <class '__main__.Discriminator'> is not a valid JAX type
I looked through the examples on the JAX GitHub page and, from what I saw, none of them use objects, which leads me to hypothesize that it may simply not be possible to use objects with JAX. But if that is the case, I don't understand why objects wouldn't work, and is this something that will be supported in the future? Am I just naively overlooking something?
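As far as I can tell, the failure happens the moment I call my jit-decorated method, so my guess is that jit treats self as a traced argument. Here is a stripped-down sketch (my assumption of the minimal case, not my actual code) that raises the same TypeError for me:

from jax import jit

class Foo():
    @jit
    def bar(self, x):
        return x + 1

Foo().bar(1.0)
# TypeError: Argument '<__main__.Foo object at ...>' of type
# <class '__main__.Foo'> is not a valid JAX type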
Here is my Discriminator class:
import time

import jax.numpy as np
from jax import grad, jit, random, vmap

class Discriminator():
    def __init__(self):
        self.step_size = 0.0001
        self.image_shape = (256, 256, 3)
        self.params = []
        num_layers = 6
        num_filters = 64
        filter_size = 4
        # First layer: 3 input channels -> 64 filters
        self.params.append(create_conv_layer(3,
                                             num_filters,
                                             filter_size,
                                             filter_size,
                                             random.PRNGKey(0)))
        # Hidden layers double the filter count each time
        for l in range(1, num_layers):
            self.params.append(create_conv_layer(64 * 2**(l - 1),
                                                 64 * 2**l,
                                                 filter_size,
                                                 filter_size,
                                                 random.PRNGKey(0)))
        # Output layer maps down to a single channel
        self.params.append(create_conv_layer(64 * 2**(num_layers - 1),
                                             1,
                                             filter_size,
                                             filter_size,
                                             random.PRNGKey(0)))

    def predict(self, params, image):
        activations = image
        for w, b in params[:-1]:
            outputs = conv_forward(activations, w, b, stride=2)
            outputs = batch_normalization(outputs)
            activations = leaky_relu(outputs)
        final_w, final_b = params[-1]
        return sigmoid(conv_forward(activations, final_w, final_b))

    def batched_predict(self, params, images):
        # Broadcast params, map over the leading (batch) axis of images
        return vmap(self.predict, in_axes=(None, 0))(params, images)

    def loss(self, params, images, targets):
        preds = self.batched_predict(params, images)
        return -np.sum(preds * targets)

    def accuracy(self, params, images, targets):
        predicted_class = np.round(np.ravel(self.batched_predict(params, images)))
        return np.mean(predicted_class == np.ravel(targets))

    @jit
    def update(self, params, x, y):
        grads = grad(self.loss)(params, x, y)
        return [(w - self.step_size * dw, b - self.step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]
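The helper functions (create_conv_layer, conv_forward, batch_normalization, leaky_relu, sigmoid) are defined elsewhere in my code and I don't think their internals matter for this error, but in case anyone wants to reproduce it, simplified stand-ins with matching signatures might look like this (these are placeholders, not my real implementations):

import jax.numpy as np
from jax import lax, random

def create_conv_layer(in_channels, out_channels, filter_h, filter_w, key):
    # Return (weights, bias) for one conv layer, HWIO weight layout
    w_key, b_key = random.split(key)
    w = random.normal(w_key, (filter_h, filter_w, in_channels, out_channels)) * 0.02
    b = np.zeros(out_channels)
    return w, b

def conv_forward(x, w, b, stride=1):
    # 2-D convolution over an NHWC batch (or a single HWC image)
    single = x.ndim == 3
    if single:
        x = x[None, ...]  # add a batch axis
    out = lax.conv_general_dilated(
        x, w, window_strides=(stride, stride), padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    out = out + b
    return out[0] if single else out

def batch_normalization(x, eps=1e-5):
    # Normalize activations to zero mean / unit variance per channel
    mean = np.mean(x, axis=(0, 1), keepdims=True)
    var = np.var(x, axis=(0, 1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def leaky_relu(x, negative_slope=0.2):
    return np.where(x >= 0, x, negative_slope * x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))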
And I update the parameters here:
num_epochs = 5
batch_size = 64
steps_per_epoch = train_images.shape[0] // batch_size
discrim = Discriminator()
params = discrim.params
print("lets-a-go!")
for epoch in range(num_epochs):
    start_time = time.time()
    for step in range(steps_per_epoch):
        x, y = simple_data_generator(batch_size)
        params = discrim.update(params, x, y)
    epoch_time = time.time() - start_time
    train_acc = discrim.accuracy(params, train_images, train_labels)
    test_acc = discrim.accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))