Writing Phenaki video to file: Expected numpy array with ndim `3` but got `4`


I'm trying to write the output of Phenaki's make_video to an mp4 file. I'm using this Phenaki implementation from GitHub: https://github.com/lucidrains/phenaki-pytorch/search?q=make_video

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
)


entire_video, scenes = make_video(phenaki, texts = [
    'blah blah',
], num_frames = (17, 14, 14), prime_lengths = (5, 5))

entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256)
torchvision.io.write_video(filename="test.mp4", video_array=entire_video, fps=24)

The error I'm getting is

  File "/.../GitHub/phenaki-pytorch/run.py", line 49, in <module>
    torchvision.io.write_video(filename= "test.mp4", video_array= entire_video, fps=24)
  File "/opt/homebrew/lib/python3.10/site-packages/torchvision/io/video.py", line 132, in write_video
    frame = av.VideoFrame.from_ndarray(img, format="rgb24")
  File "av/video/frame.pyx", line 408, in av.video.frame.VideoFrame.from_ndarray
  File "av/utils.pyx", line 72, in av.utils.check_ndarray
ValueError: Expected numpy array with ndim `3` but got `4`

What am I doing wrong? Why does av.VideoFrame.from_ndarray expect a numpy array with 3 dimensions?

1 Answer

According to the write_video documentation, the video_array argument must be a "tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format".
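
As a minimal illustration of that contract (a synthetic clip, not Phenaki output), a uint8 tensor already shaped [T, H, W, C] writes without error:

import torch
import torchvision

# 10 frames of 64x64 RGB noise in the [T, H, W, C] uint8 layout write_video expects
dummy = torch.randint(0, 256, (10, 64, 64, 3), dtype=torch.uint8)
torchvision.io.write_video(filename="dummy.mp4", video_array=dummy, fps=24)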

The shape of entire_video in the sample below is (1, 3, 45, 256, 128) (with your square configuration it would be (1, 3, 45, 256, 256)); either way, that is 5 dimensions instead of the expected 4.
The exception says ndim 3 but got 4 (rather than 4 and 5) because the mismatch is detected inside write_video's internal loop, which passes one frame at a time to PyAV.
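
A small sketch of why those specific numbers appear: write_video iterates over the first axis and hands each slice to av.VideoFrame.from_ndarray, which expects a 3-D (H, W, C) array per frame:

import numpy as np

video_5d = np.zeros((1, 3, 45, 256, 128), dtype=np.uint8)
for img in video_5d:   # write_video loops over the first axis like this
    print(img.ndim)    # 4 - one more than the expected 3, hence the ValueError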

The order of the dimensions is also wrong: the 3 is the number of color channels, and channels should be the last dimension.
The dtype of entire_video is also wrong: it is float32 instead of uint8.
Assuming entire_video resides in GPU memory, we also have to copy the tensor to CPU memory before using write_video.
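
A quick way to confirm all three problems at once (the device line assumes a CUDA run, as in the sample below):

print(entire_video.shape)   # torch.Size([1, 3, 45, 256, 128]) - 5 dims, channels before frames
print(entire_video.dtype)   # torch.float32 - write_video expects torch.uint8
print(entire_video.device)  # e.g. cuda:0 - write_video needs a CPU tensor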


Before using write_video, we have to apply the following steps:

  • Copy the video from GPU memory to CPU memory (and remove the redundant batch axis):

     entire_video = entire_video[0].detach().cpu()
    
  • Convert from float32 to uint8 by applying an offset and a scale.
    The following code uses the global minimum and maximum (the conversion is not optimal - used as an example; a fixed-range alternative is sketched right after this list):

     min_val = entire_video.min()
     max_val = entire_video.max()
     entire_video_as_uint8 = ((entire_video - min_val) * 255/(max_val - min_val)).to(torch.uint8)
    
  • Reorder the axes to [T, H, W, C]:

    • The first axis is the frame index (shape value 45, because there are 45 video frames).

    • The second axis is the row index (shape value 256, because there are 256 rows in each frame).

    • The third axis is the column index (shape value 128, because there are 128 columns in each frame).

    • The fourth axis is the color channel (shape value 3, because there are 3 color channels - red, green and blue).

        vid_arr = torch.permute(entire_video_as_uint8, (1, 2, 3, 0))
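
If the model output is known to lie in [0, 1] (an assumption - check what your checkpoint actually produces), a fixed scale is a simpler alternative to the global min/max and avoids brightness flicker between videos:

entire_video_as_uint8 = (entire_video.clamp(0.0, 1.0) * 255).to(torch.uint8)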
      

Complete code sample:

import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki
from phenaki_pytorch import make_video
import torchvision

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

cvivit = CViViT(
    dim = 512,
    codebook_size = 5000,
    image_size = (256, 128),  # video with rectangular screen allowed
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

entire_video, scenes = make_video(phenaki, texts = [
    'blah blah'
], num_frames=(17, 14, 14), prime_lengths=(5, 5))

print(entire_video.shape)  # (1, 3, 45, 256, 128)

# Copy the video from the GPU memory to CPU memory.
# Apply entire_video[0] for removing redundant axis.
entire_video = entire_video[0].detach().cpu()  # https://stackoverflow.com/a/66754525/4926757

# Convert from float32 to uint8, use global minimum and global maximum - this is not the best solution
min_val = entire_video.min()
max_val = entire_video.max()
entire_video_as_uint8 = ((entire_video - min_val) * 255/(max_val-min_val)).to(torch.uint8)

# https://pytorch.org/vision/stable/generated/torchvision.io.write_video.html
# video_array - (Tensor[T, H, W, C]) – tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format
# https://pytorch.org/docs/stable/generated/torch.permute.html
vid_arr = torch.permute(entire_video_as_uint8, (1, 2, 3, 0))  # Reorder the axes to be ordered as [T, H, W, C]


print(vid_arr.shape)  # (45, 256, 128, 3)
torchvision.io.write_video(filename="test.mp4", video_array=vid_arr, fps=24)
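
Optionally, reading the file back verifies the layout and frame rate (read_video returns frames in [T, H, W, C] order):

frames, _, meta = torchvision.io.read_video("test.mp4", pts_unit="sec")
print(frames.shape)       # torch.Size([45, 256, 128, 3])
print(meta["video_fps"])  # 24.0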

After all that, the created video file looks like random noise...


It looks like this is simply the raw output of make_video, though, and not related to the subject of this post.