How to split a tiled PyTorch tensor into a bunch of little tensors?

I have a PyTorch tensor that I have tiled so that it looks like a matrix of lots of little matrices. I would like to split it into a list of lists of those matrices, because I have additional info that I need to associate with each of them. I would like a "pythonic" way of doing so.

   original_tensor.shape == (24, 32)
   tensor_to_split = original_tensor.reshape(6, 8, 4, 4)  # not sure about this reshape
   tensor_to_split.shape == (6, 8, 4, 4)  # it's a 6x8 matrix of 4x4 matrices
   # desired result: a list of 6 lists, each holding 8 tensors of shape (4, 4)
   desired_tensors = []
   # some code vaguely like:
   for r in range(6):
       clist = []
       for c in range(8):
           clist.append(tensor_to_split[r][c])  # I don't know how to write this line
       desired_tensors.append(clist)

I haven't tried the above code and am not certain it would even run, because of the odd subscripting operation I want to do. Maybe I need a notation for slices of slices.
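For what it's worth, chained subscripting such as `tensor_to_split[r][c]` is valid on a PyTorch tensor and selects the same 4x4 block as the tuple form `tensor_to_split[r, c]`. The sketch below checks this and builds the list of lists without explicit loops. The stand-in tensor built with `torch.arange` is an assumption, since the real data isn't shown.

```python
import torch

# Stand-in for the 4-D tensor from the question (values are arbitrary).
tensor_to_split = torch.arange(6 * 8 * 4 * 4).reshape(6, 8, 4, 4)

# Chained indexing and tuple indexing pick out the same 4x4 block.
assert torch.equal(tensor_to_split[2][5], tensor_to_split[2, 5])

# unbind(0) splits a tensor along its first dimension into a tuple of views,
# so this yields 6 lists, each holding 8 tensors of shape (4, 4).
desired_tensors = [list(row.unbind(0)) for row in tensor_to_split.unbind(0)]
```

The resulting tensors are views into `tensor_to_split` rather than copies, so an in-place change to one of them would also change the parent tensor; calling `.clone()` on each block would avoid that if it matters.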

I have written similar code before that uses nested for loops to build lists and then lists of lists, so that part is "ok".
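On the reshape that the question flags as uncertain: if the tiles are meant to be contiguous 4x4 blocks of the 24x32 layout (an assumption about how the tensor was tiled), a direct `reshape(6, 8, 4, 4)` would not recover them, because it only re-chunks consecutive elements in row-major order. One common approach, sketched here with an arbitrary stand-in tensor, is to split each axis into (tile index, offset within tile) and then move the two tile indices to the front:

```python
import torch

# Stand-in for the 24x32 tensor from the question (values are arbitrary).
original_tensor = torch.arange(24 * 32).reshape(24, 32)

# Direct reshape: block [0, 0] is the first 16 elements of row 0,
# not the top-left 4x4 block of the matrix.
naive = original_tensor.reshape(6, 8, 4, 4)
assert not torch.equal(naive[0, 0], original_tensor[0:4, 0:4])

# Split rows into (tile_row, row_in_tile) and columns into (tile_col, col_in_tile),
# then reorder the axes to (tile_row, tile_col, row_in_tile, col_in_tile).
tiles = original_tensor.reshape(6, 4, 8, 4).permute(0, 2, 1, 3)

# Tile (r, c) now matches the corresponding 4x4 block of the original.
r, c = 2, 5
assert torch.equal(tiles[r, c], original_tensor[4 * r:4 * r + 4, 4 * c:4 * c + 4])

# List of 6 lists, each holding 8 tensors of shape (4, 4).
desired_tensors = [[tiles[i, j] for j in range(8)] for i in range(6)]
```

If the tensor was instead laid out so that each run of 16 consecutive elements already forms one 4x4 matrix, the original `reshape(6, 8, 4, 4)` would be the right call and the permute step would not be needed.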
