I have code that is simulating interactions between lots of particles. Using profiling, I've worked out that the function that's causing the most slowdown is a loop which iterates over all my particles and works out the time for the collision between each of them. This generates a symmetric matrix, which I then take the minimum value out of.
def find_next_collision(self, print_matrix = False):
    """
    Sets up a matrix of collision times
    Returns the indices of the balls in self.list_of_balls that are due to
    collide next and the time to the next collision
    """
    self.coll_time_matrix = np.zeros((np.size(self.list_of_balls), np.size(self.list_of_balls)))
    for i in range(np.size(self.list_of_balls)):
        for j in range(i+1):
            if (j==i):
                self.coll_time_matrix[i][j] = np.inf
            else:
                self.coll_time_matrix[i][j] = self.list_of_balls[i].time_to_collision(self.list_of_balls[j])
    matrix = self.coll_time_matrix + self.coll_time_matrix.T
    self.coll_time_matrix = matrix
    ind = np.unravel_index(np.argmin(self.coll_time_matrix, axis = None), self.coll_time_matrix.shape)
    dt = self.coll_time_matrix[ind]
    if (print_matrix):
        print(self.coll_time_matrix)
    return dt, ind
This code is a method inside a class which defines the positions of all of the particles. Each of these particles is an object saved in self.list_of_balls (which is a list). As you can see, I'm already only iterating over half of this matrix, but it's still quite a slow function. I've tried using numba, but this is a section of quite a large codebase, and I don't want to have to optimize every function with numba when this is the slow one.
Does anyone have any ideas for a more efficient way to write this function?
Thank you in advance!
First question: how many particles do you have?
If you have a lot of particles, one improvement would be to take the diagonal check out of the inner loop. Frequently executed if statements slow everything down, so avoid them in inner loops.
Second question: is it necessary to compute this every time? Wouldn't it be faster to compute the points in time of the collisions once and only refresh those rows and columns which were involved in a collision?
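Going back to the first point, here is a minimal sketch of the loop without the diagonal check. It is written as free functions rather than a method, and it takes a hypothetical `time_to_collision(a, b)` callable as a stand-in for the ball method in the question. The matrix is pre-filled with `np.inf`, so the diagonal never has to be touched:

```python
import numpy as np

def collision_time_matrix(balls, time_to_collision):
    """Build the symmetric matrix of pairwise collision times.

    `time_to_collision(a, b)` is an assumed stand-in for
    `a.time_to_collision(b)` from the question.
    """
    n = len(balls)
    # Start from inf everywhere, so the diagonal needs no special case.
    times = np.full((n, n), np.inf)
    for i in range(n):
        for j in range(i):  # strictly below the diagonal, no `if` needed
            t = time_to_collision(balls[i], balls[j])
            times[i, j] = t
            times[j, i] = t
    return times

def find_next_collision(times):
    """Return (dt, (i, j)) for the smallest entry of the matrix."""
    ind = np.unravel_index(np.argmin(times), times.shape)
    return times[ind], ind
```

This still makes n*(n-1)/2 calls to `time_to_collision`, so it only removes the branch and the extra transpose, not the quadratic cost.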
Edit:
The idea here is to initially calculate either the time left or (better solution) the timestamp of each collision, as you already do, as well as the order. But instead of throwing the calculated results away, you only update the values when needed. This way you need to calculate only 2*n instead of n^2/2 values.
Sketch:
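A minimal version of this idea, assuming absolute timestamps are stored and that a collision only changes the two balls involved. The class name `CollisionTable`, its methods, and the `time_to_collision(a, b)` callable are hypothetical names made up for illustration. The optional `min_times` and `order_coll` bookkeeping is left out:

```python
import numpy as np

class CollisionTable:
    """Caches the absolute timestamp of the collision for every pair."""

    def __init__(self, balls, time_to_collision, now=0.0):
        self.balls = balls
        # Assumed stand-in for `a.time_to_collision(b)` from the question.
        self.time_to_collision = time_to_collision
        n = len(balls)
        self.stamps = np.full((n, n), np.inf)  # diagonal stays inf
        # Full computation happens once, over the lower triangle only.
        for i in range(n):
            for j in range(i):
                self._update_pair(i, j, now)

    def _update_pair(self, i, j, now):
        # Store a timestamp (now + time left), so old entries stay valid.
        stamp = now + self.time_to_collision(self.balls[i], self.balls[j])
        self.stamps[i, j] = stamp
        self.stamps[j, i] = stamp

    def refresh(self, i, now):
        """Recompute only row and column i, e.g. after ball i collided."""
        n = len(self.balls)
        for j in range(i):
            self._update_pair(i, j, now)
        for j in range(i + 1, n):
            self._update_pair(i, j, now)

    def next_collision(self):
        """Return (timestamp, (i, j)) of the earliest cached collision."""
        ind = np.unravel_index(np.argmin(self.stamps), self.stamps.shape)
        return self.stamps[ind], ind
```

After a collision between balls i and j at time t, the simulation would move all balls to t, update the two velocities, and then call `refresh(i, t)` and `refresh(j, t)`. That is roughly 2*n pair evaluations per collision instead of a full rebuild.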
If you store the time left in the matrix, you have to subtract the time passed from the matrix after every step. You also need to keep the matrix, and (optionally) min_times and order_coll, somewhere between calls.
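For the time-left variant, the subtraction is a single in-place NumPy operation. A small sketch, with `advance_time_left` being a hypothetical helper name:

```python
import numpy as np

def advance_time_left(time_left, dt):
    """Subtract the elapsed time dt from every entry, in place.

    Entries equal to np.inf (the diagonal, or pairs that never collide)
    stay np.inf, because inf - dt is still inf.
    """
    time_left -= dt
    return time_left
```

Rows and columns of balls that just collided would still need to be recomputed afterwards, since their old entries no longer apply.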