Fix a legend in an animation created by celluloid


I want to animate the process of finding the minimum point of a function with different gradient descent optimization methods. For this purpose, I am using the matplotlib and celluloid packages. The problem is that I cannot keep the legend of the plot fixed in the animation: in each loop iteration a new legend is added below the previous one. Is there any way to fix the legend and avoid this problem?

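import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from celluloid import Camera

# x_mesh, y_mesh, z, minima_ and path1 ... path10 are defined earlier (not shown)
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
camera = Camera(fig)
for i in range(path1.shape[1]):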
  ax.contour(x_mesh, y_mesh, z, levels=np.logspace(0, 5, 35), norm=LogNorm(), cmap=plt.cm.jet)
  ax.plot(*minima_, 'r*', markersize=18)

  line, = ax.plot([], [], 'k', label='Simple SGD', lw=2)
  point, = ax.plot([], [], 'ko')
  line.set_data(*path1[::,:i])
  point.set_data(*path1[::,i-1:i])

  line, = ax.plot([], [], 'r', label='SGD with momentum', lw=2)
  point, = ax.plot([], [], 'ro')
  line.set_data(*path2[::,:i])
  point.set_data(*path2[::,i-1:i])

  line, = ax.plot([], [], 'g', label='SGD with Nesterov', lw=2)
  point, = ax.plot([], [], 'go')
  line.set_data(*path3[::,:i])
  point.set_data(*path3[::,i-1:i])

  line, = ax.plot([], [], 'b', label='SGD with Adagrad', lw=2)
  point, = ax.plot([], [], 'bo')
  line.set_data(*path4[::,:i])
  point.set_data(*path4[::,i-1:i])

  line, = ax.plot([], [], 'c', label='SGD with Adadelta', lw=2)
  point, = ax.plot([], [], 'co')
  line.set_data(*path5[::,:i])
  point.set_data(*path5[::,i-1:i]) 

  line, = ax.plot([], [], 'm', label='SGD with RMSprop', lw=2)
  point, = ax.plot([], [], 'mo')
  line.set_data(*path6[::,:i])
  point.set_data(*path6[::,i-1:i])

  line, = ax.plot([], [], 'y', label='SGD with Adam', lw=2)
  point, = ax.plot([], [], 'yo')
  line.set_data(*path7[::,:i])
  point.set_data(*path7[::,i-1:i])

  line, = ax.plot([], [], 'y', label='SGD with Adamax', lw=2)
  point, = ax.plot([], [], 'y*')
  line.set_data(*path8[::,:i])
  point.set_data(*path8[::,i-1:i])

  line, = ax.plot([], [], 'k', label='SGD with Nadam', lw=2)
  point, = ax.plot([], [], 'kp')
  line.set_data(*path9[::,:i])
  point.set_data(*path9[::,i-1:i])

  line, = ax.plot([], [], 'r', label='SGD with AMSGrad', lw=2)
  point, = ax.plot([], [], 'rD')
  line.set_data(*path10[::,:i])
  point.set_data(*path10[::,i-1:i])

  ax.legend(loc='upper left') 
  camera.snap()
animation = camera.animate()
animation.save('2D_animation_overlap.gif', writer='imagemagick')
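The growth comes from the way celluloid is used: the axes are never cleared, so the labeled lines from every earlier frame are still attached to `ax` when `ax.legend()` is called again. The sketch below reproduces this with plain Matplotlib (no celluloid; two hypothetical lines per frame stand in for the ten optimizers) and counts the labels an automatic legend would collect after each frame.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; no window is needed for this check
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
label_counts = []
for frame in range(3):
    # Each "frame" adds new labeled lines, as the loop in the question does;
    # the lines from earlier frames are never removed from the axes.
    ax.plot([0, frame], [0, frame], "k", label="Simple SGD", lw=2)
    ax.plot([0, frame], [frame, 0], "r", label="SGD with momentum", lw=2)
    # ax.legend() with no arguments collects every labeled artist on the axes.
    _, labels = ax.get_legend_handles_labels()
    label_counts.append(len(labels))

print(label_counts)  # [2, 4, 6]
```

With ten optimizers per frame, the automatic legend gains ten entries on every snapshot, which is the stacking seen in the animation.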


Best answer, by William Miller

The best practice here is to create a custom legend from explicit handles instead of letting Matplotlib generate one automatically. In this case that could be done with:

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

labels = ['Simple SGD', 'SGD with momentum', 'SGD with Nesterov',
          'SGD with Adagrad', 'SGD with Adadelta', 'SGD with RMSprop', 'SGD with Adam',
          'SGD with Adamax', 'SGD with Nadam', 'SGD with AMSGrad']
colors = ['k', 'r', 'g', 'b', 'c', 'm', 'y', 'y', 'k', 'r']
# Proxy artists: they are never added to the axes, so the legend
# does not grow as the animation adds more lines.
handles = []
for c, l in zip(colors, labels):
    handles.append(Line2D([0], [0], color=c, lw=2, label=l))

plt.legend(handles=handles, loc='upper left')
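Several optimizers in the question share a color ('y' for Adam and Adamax, 'k' for simple SGD and Nadam, 'r' for momentum and AMSGrad), so the entries can be told apart by also giving each handle the marker used for its current-position point. A sketch of that variant, assuming the `markers` list below mirrors the point styles in the question's loop and that the legend is attached to the `ax` object rather than through `plt`:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend for a quick check
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

labels = ['Simple SGD', 'SGD with momentum', 'SGD with Nesterov',
          'SGD with Adagrad', 'SGD with Adadelta', 'SGD with RMSprop', 'SGD with Adam',
          'SGD with Adamax', 'SGD with Nadam', 'SGD with AMSGrad']
colors = ['k', 'r', 'g', 'b', 'c', 'm', 'y', 'y', 'k', 'r']
# Markers taken from the question's point styles ('ko', ..., 'y*', 'kp', 'rD').
markers = ['o', 'o', 'o', 'o', 'o', 'o', 'o', '*', 'p', 'D']

# Proxy artists: never added to the axes, used only to draw legend entries.
handles = [Line2D([0], [0], color=c, marker=m, lw=2, label=l)
           for c, m, l in zip(colors, markers, labels)]

fig, ax = plt.subplots(figsize=(10, 10))
legend = ax.legend(handles=handles, loc='upper left')
legend_labels = [t.get_text() for t in legend.get_texts()]
print(len(legend_labels))  # 10
```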

which gives a single, fixed legend with one entry per optimizer.


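None of this needs to be in the loop: you can create the legend before or after it and it will still work, because the legend is built from the fixed list of handles rather than from the lines that accumulate on the axes. It would also work inside the loop, but redrawing the legend on every frame is unnecessary.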
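To check that a handle-based legend keeps its size even when it is redrawn inside the loop, the sketch below simulates the accumulating lines with plain Matplotlib (two hypothetical optimizers, three frames, no celluloid) and rebuilds the legend from the same handles on every frame.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend for a quick check
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

labels = ['Simple SGD', 'SGD with momentum']
colors = ['k', 'r']
handles = [Line2D([0], [0], color=c, lw=2, label=l) for c, l in zip(colors, labels)]

fig, ax = plt.subplots()
entries_per_frame = []
for frame in range(3):
    # Lines accumulate on the axes, as they do between celluloid snapshots.
    for c, l in zip(colors, labels):
        ax.plot([0, frame], [0, frame], c, label=l, lw=2)
    # A legend built from fixed handles does not depend on how many
    # lines are on the axes; each call replaces the axes' previous legend.
    legend = ax.legend(handles=handles, loc='upper left')
    entries_per_frame.append(len(legend.get_texts()))

print(entries_per_frame)  # [2, 2, 2]
print(len(ax.lines))      # 6
```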

Instead of creating the legend manually, it would also suffice to guard the legend creation with an if statement so that the automatic legend is created only on the first frame, i.e.:

    # ...
    if i == 0:
        ax.legend(loc = 'upper left')
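The guard works because at `i == 0` the axes hold only the first frame's lines, so the automatic legend gets exactly one entry per optimizer, and the axes keep that legend until `ax.legend` is called again. A quick check of that behavior, again with two hypothetical lines per frame and plain Matplotlib instead of celluloid:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend for a quick check
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
for i in range(3):
    ax.plot([0, i], [0, i], "k", label="Simple SGD", lw=2)
    ax.plot([0, i], [i, 0], "r", label="SGD with momentum", lw=2)
    if i == 0:
        # Only the first frame's lines exist at this point.
        ax.legend(loc="upper left")

# The legend created on the first frame is still the one attached to the axes.
legend_texts = [t.get_text() for t in ax.get_legend().get_texts()]
print(legend_texts)  # ['Simple SGD', 'SGD with momentum']
```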

However, I would recommend creating the legend directly rather than coaxing the automatic legend generation into producing the right result.