Which is the fastest method to calculate mean squared error in a large image dataset?

I'm trying to calculate the mean squared error on an image dataset (CIFAR-10). I have a NumPy array of shape (5, 10000, 32, 32, 3): 5 batches of 10000 images, each of size 32x32x3. The images belong to 10 classes. I have calculated the average image of each class, and now I want to compute the mean squared error of each of the 50000 images with respect to the 10 average images, and assign each image the class whose average is closest. Here is the code:

for i in range(0, 5):
    for j in range(0, 10000):
        min_diff, min_class = float('inf'), 0
        for avg in class_avg:  # class_avg holds the 10 (label, average image) pairs
            temp = mse(avg[1], images[i][j])
            if temp < min_diff:
                min_diff = temp
                min_class = avg[0]
        train_pred[i][j] = min_class

Is there any way to make this faster? Any NumPy magic? Thank you.

BEST ANSWER

You can use expand_dims and tile.

There are many ways to add a dimension to an array. I will use indexing with None, e.g. [:,None,:], which inserts a new axis in the middle.
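To see what None indexing does, and that it matches np.expand_dims, here is a small check. The array a is made up purely for illustration; the Ellipsis (...) used in the code below stands for "all remaining axes".

```python
import numpy as np

a = np.zeros((4, 5))

# Indexing with None inserts a length-1 axis at that position
b = a[:, None, :]
# np.expand_dims does the same, with the position given by axis
c = np.expand_dims(a, axis=1)
print(b.shape)  # (4, 1, 5)
print(c.shape)  # (4, 1, 5)

# With an Ellipsis, None can prepend axes without listing the others
d = a[None, None, ...]
print(d.shape)  # (1, 1, 4, 5)
```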

Below is an example of how you can combine the two for your task:

import numpy as np

test = np.ones((5,100,32,32,3)) # batches of images
average = np.ones((10,32,32,3)) # the 10 average images
average = average[None,None,...] # add two axes: (1,1,10,32,32,3)

test = test[:,:,None,...] # insert an axis: (5,100,1,32,32,3)
test = np.tile(test,(1,1,10,1,1,1)) # repeat 10 times along it: (5,100,10,32,32,3)
print(test.shape,average.shape)

mse = ((test-average)**2).mean(axis=(3,4,5)) # (5,100,10) MSE per image and class
class_idx = np.argmin(mse,axis=-1) # (5,100) index of the closest class
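A side note, not part of the original answer: NumPy broadcasting already stretches length-1 axes, so the np.tile call is not strictly needed. Broadcasting saves the tiled copy of test, but the difference array is still materialized at full size (images x 10 x 32 x 32 x 3). The sketch below also maps the argmin indices back to class labels, assuming class_avg is a list of (label, average image) pairs as in the question; the random data and the smaller batch sizes are made up so it runs quickly.

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up data, smaller than CIFAR-10
test = rng.random((2, 50, 32, 32, 3))                          # batches of images
class_avg = [(k, rng.random((32, 32, 3))) for k in range(10)]  # hypothetical (label, image) pairs

class_labels = np.array([label for label, _ in class_avg])  # (10,)
average = np.stack([img for _, img in class_avg])           # (10, 32, 32, 3)

# No np.tile: the length-1 axis inserted into test is broadcast against the
# 10 class images; the difference array is still (2, 50, 10, 32, 32, 3)
diff = test[:, :, None, ...] - average[None, None, ...]
mse = (diff ** 2).mean(axis=(3, 4, 5))  # (2, 50, 10) MSE per image and class
class_idx = np.argmin(mse, axis=-1)      # (2, 50) index into class_avg
train_pred = class_labels[class_idx]     # (2, 50) labels, like train_pred in the question

print(train_pred.shape)  # (2, 50)
```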

UPDATE

The purpose of expand_dims and tile is to avoid a Python for-loop. However, np.tile creates 10 copies of the original array, which will hurt both speed and memory use if the array is large. To avoid np.tile, you can loop over the 10 classes only, as in the code below:

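import numpy as np

labels = np.empty((5,100,10))   # MSE of every image against each of the 10 classes
average = np.ones((10,32,32,3)) # the 10 average images
average = average[None,...]     # shape (1,10,32,32,3)

test = np.ones((5,100,32,32,3)) # batches of images

# Each iteration broadcasts one average image against all test images,
# so no tiled copy of test is ever built
for ind in range(10):
    labels[...,ind] = ((test-average[:,ind,...])**2).mean(axis=(2,3,4))
labels = np.argmin(labels,axis=-1) # (5,100) index of the closest class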
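If memory is still a problem, a different technique (not the one this answer uses) avoids any array of shape (..., 10, 32, 32, 3): expand the squared distance as ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, so the cross term for all images and all classes becomes a single matrix product. The sketch below also casts to float first, because if the images are stored as uint8 (common for CIFAR-10), subtracting or squaring them wraps around instead of going negative or growing. The function name nearest_class and the random data are hypothetical; treat this as a starting point rather than a tuned implementation.

```python
import numpy as np


def nearest_class(images, averages):
    """For each image, find the average image with the lowest MSE.

    images:   array of shape (..., H, W, C), e.g. (5, 10000, 32, 32, 3)
    averages: array of shape (K, H, W, C), e.g. (10, 32, 32, 3)
    Returns (idx, mse) with shapes (...) and (..., K).
    """
    batch_shape = images.shape[:-3]
    k = averages.shape[0]
    d = int(np.prod(averages.shape[1:]))  # values per image, 32*32*3 = 3072

    # Flatten each image to a row; cast to float so uint8 pixels cannot wrap
    x = images.reshape(-1, d).astype(np.float64)  # (N, D)
    c = averages.reshape(k, d).astype(np.float64)  # (K, D)

    # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, cross term as one matmul
    x_sq = np.einsum("nd,nd->n", x, x)[:, None]  # (N, 1)
    c_sq = np.einsum("kd,kd->k", c, c)[None, :]  # (1, K)
    sq_dist = x_sq - 2.0 * (x @ c.T) + c_sq      # (N, K)
    # Rounding can push near-zero distances slightly below 0; clamp them
    mse = np.maximum(sq_dist, 0.0) / d

    idx = np.argmin(mse, axis=1)
    return idx.reshape(batch_shape), mse.reshape(*batch_shape, k)


# Usage with made-up uint8 data shaped like a small slice of CIFAR-10
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(2, 30, 32, 32, 3)).astype(np.uint8)
averages = rng.random((10, 32, 32, 3)) * 255.0
idx, mse = nearest_class(images, averages)
print(idx.shape, mse.shape)  # (2, 30) (2, 30, 10)
```

For all 50000 images, the flattened float64 copy alone is about 1.2 GB (50000 x 3072 x 8 bytes). If that is too much, call the function on one batch at a time (images[i] for each of the 5 batches), or cast to float32 instead, at some cost in precision in the subtraction of large squared norms.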