I am working with Instant-NGP TL. I am trying to calculate the optical flow using the output of the volume rendering. My approach to this is the following:
- Take sampled points along each ray calculated by the ray marcher.
- Take the weights calcuated for each point along the ray by the volume rendre
- The Sampled points are calculated in the world coordinates so They are first transformed to screen space using the pose
- Multiply the weights by the sampled points and sum all the points along each ray
- Calculate acc_map = weights. The acc_map is summs all the weights along each ray
- Multiply (1 - acc_map) * (rays_o + rays_d) where rays_o are the ray origins and rays_d are directions
- sum the weighted points with the weighted rays
- let's call the current frame i and the frame after it i+1 so I take the pose of frame i+1 and convert the result of the sum to the world space
- using the inverse pose of frame i+1 I convert back to the camera space
- using the intrinsics matrix I calculate the pixel position
The following approach is currently not working as the results calculated seems to be too far off from optical flow calculated by RAFT.
This is the code i used to do the calculation
H => image height
W => image width
focal => the focal length
c2w_current => the pose which convert from the camera space to the world space for the current image
c2w_shifted => the pose of the neighbour image (i+1) next to the current image (i) from camera space to world space
weights => the weights calculated by the volumen render
xyzs => the sampled points along each ray calculated by ray marching algorithm
rays => [#num_of_image, 6] the first three elements contain the ray origin and the last three elements contain the ray direction
rays_a => contain three values the ray_idx, start_idx, num_samples . the ray_idx gives the origin and direction in rays array, start_idx gives the starting index of the sampled points and weights arrays to get the last element sum start_idx with num_samples
grid => is 2d meshgrid which starts from 0 to H-1 and from 0 to W-1 and is used to calculate the optical flow
def calculate_flow_2(H, W, focal, c2w_current, c2w_shifted, weights, xyzs, rays, rays_a, grid, scale=1.0):
rays_d = rearrange(rays[:, 3:], 'n c -> n 1 c') @ rearrange(torch.inverse(c2w_current[..., :3]), 'n a b -> n b a')
rays_d = rearrange(rays_d, 'n 1 c -> n c')
rays_with_shifted_origins = rays[:, :3] - c2w_current[:, :, 3]
camera_space_rays = torch.cat([rays_with_shifted_origins, rays_d], -1).view(-1, 6)
c2w_current_4x4 = torch.zeros((c2w_current.shape[0], 4, 4), device=c2w_current.device)
c2w_current_4x4[: , 3, 3] = 1
c2w_current_4x4[:, :3, :4] = c2w_current[:, :3, :4]
w2c_current = torch.inverse(c2w_current_4x4)
idx_arr = torch.arange(0, xyzs.shape[0], device=rays.device)
acc_weights = torch.zeros((rays.shape[0]), device=xyzs.device)
for ray_idx, start_idx, num_samples in rays_a:
acc_weights[ray_idx] = torch.sum(weights[start_idx:start_idx + num_samples])
idx_arr[start_idx:start_idx + num_samples] = ray_idx
rotated_points = torch.matmul(w2c_current[idx_arr, :3, :3], xyzs[..., None])
shifted_points = torch.squeeze(rotated_points) + w2c_current[idx_arr, :3, 3]
weighted_points_no_sum = weights[..., None] * shifted_points
weighted_points = torch.zeros((rays.shape[0], 3), device=xyzs.device)
for ray_idx, start_idx, num_samples in rays_a:
weighted_points[ray_idx] = torch.sum(weighted_points_no_sum[start_idx:start_idx + num_samples], dim=0)
weighted_ray = (1 - acc_weights[..., None]) * (camera_space_rays[:, :3] + camera_space_rays[:, 3:])
camera_space_point = weighted_points + weighted_ray
c2w_shifted_4x4 = torch.zeros((c2w_shifted.shape[0], 4, 4), device=c2w_shifted.device)
c2w_shifted_4x4[:, 3, 3] = 1
c2w_shifted_4x4[:, :3, :4] = c2w_shifted[:, :3, :4]
w2c_shifted = torch.inverse(c2w_shifted_4x4)
rotated_points = torch.matmul(c2w_shifted_4x4[rays_a[..., 0], :3, :3], camera_space_point[..., None])
shifted_points = torch.squeeze(rotated_points) + c2w_shifted_4x4[rays_a[..., 0], :3, 3]
rotated_points = torch.matmul(w2c_shifted[rays_a[..., 0], :3, :3], shifted_points[..., None])
shifted_points = torch.squeeze(rotated_points) + w2c_shifted[rays_a[..., 0], :3, 3]
point_map = torch.zeros((rays.shape[0], 2), device=xyzs.device)
point_map[:, 0] = (shifted_points[:, 0] / shifted_points[:, 2]) * focal + W * 0.5
point_map[:, 1] = (shifted_points[:, 1] / shifted_points[:, 2]) * focal + H * 0.5
return point_map - grid
The values of the optical flow function should be close to the values estimated by RAFT