I made a simple replay buffer, but when I sample from it I get the error `TypeError: 'type' object is not iterable`:
```python
import collections

import numpy as np

Experience = collections.namedtuple("Experience", field_names=["state", "action", "reward", "done", "next_state"])


class ReplayBuffer:
    def __init__(self, capacity):
        # Oldest experiences are dropped once capacity is reached.
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def add_exp(self, exp: Experience):
        self.buffer.append(exp)

    def sample(self, batch_size):
        # Pick batch_size distinct indices, then transpose the sampled
        # experiences into one tuple per field.
        idxs = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in idxs])
        return np.array(states), np.array(actions), \
               np.array(rewards, dtype=np.float32), \
               np.array(dones, dtype=np.uint8), \
               np.array(next_states)
```
When I print the type of `self.buffer[0]`, it gives `type`, but shouldn't it be `ReplayBuffer.Experience`?
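You're adding a type to your buffer, not an instance of the type: `self.buffer[0]` is the `Experience` class itself. When `sample` transposes the batch with `zip(*...)`, `zip` tries to iterate over each element, and a class is not iterable, hence `TypeError: 'type' object is not iterable`.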
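A minimal sketch of the difference, assuming the buffer was filled with something like `add_exp(Experience)` (the calling code isn't shown in the question). A plain list stands in for the deque, and the field values are made-up placeholders:

```python
import collections

Experience = collections.namedtuple(
    "Experience", field_names=["state", "action", "reward", "done", "next_state"]
)

buffer = []

# The mistake: appending the class itself, not an object built from it.
buffer.append(Experience)
print(type(buffer[0]))  # <class 'type'>

try:
    # zip(*batch) has to iterate over each element; a class is not iterable.
    states, actions, rewards, dones, next_states = zip(*buffer)
except TypeError as err:
    print(err)  # 'type' object is not iterable

# The fix: call Experience(...) to create an instance, then append that.
# These field values are placeholders for illustration only.
buffer.clear()
buffer.append(Experience(state=[0.0, 1.0], action=1, reward=0.5, done=False, next_state=[1.0, 1.0]))
print(type(buffer[0]) is Experience)  # True

states, actions, rewards, dones, next_states = zip(*buffer)
print(actions)  # (1,)
```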
You need to create an instance of `Experience` first, by calling `Experience(...)` with the data you want to store (state, action, reward, done flag and next state), and then add that instance to the buffer.

Also note, the more modern way to write `Experience` is with `class` and `typing.NamedTuple`, annotating each field with its type. This allows type checkers to help you catch type errors.
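A sketch of the `typing.NamedTuple` form used with the same buffer. The field types (`np.ndarray` for states, `int` for a discrete action) and the dummy transitions are assumptions for illustration; adjust them to your environment. Note that each `Experience` is constructed before being passed to `add_exp`, which is what avoids the original error:

```python
import collections
from typing import NamedTuple

import numpy as np


class Experience(NamedTuple):
    # Field types are illustrative assumptions; change them to match your data.
    state: np.ndarray
    action: int
    reward: float
    done: bool
    next_state: np.ndarray


class ReplayBuffer:
    def __init__(self, capacity: int) -> None:
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def add_exp(self, exp: Experience) -> None:
        self.buffer.append(exp)

    def sample(self, batch_size: int):
        idxs = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in idxs])
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.uint8),
            np.array(next_states),
        )


buffer = ReplayBuffer(capacity=100)
for step in range(10):
    # Dummy transitions for illustration only.
    state = np.zeros(4, dtype=np.float32)
    next_state = np.ones(4, dtype=np.float32)
    buffer.add_exp(Experience(state, step % 2, 1.0, False, next_state))

states, actions, rewards, dones, next_states = buffer.sample(batch_size=4)
print(states.shape, actions.shape, rewards.dtype)  # (4, 4) (4,) float32
```

The annotations are not checked at runtime; a `NamedTuple` instance is still a plain tuple, so `zip(*...)` in `sample` works unchanged. The benefit is that a static type checker can flag a call like `add_exp(Experience)` because the argument is a class rather than an `Experience` instance.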