Visit count in MCTS search keeps being zero

51 Views Asked by At

I am coding an AI project that involves an MCTS search. After each iteration, the search is supposed to backpropagate the result up the tree and also increase the visit_count of each parent by 1, which is done in this function:

def backpropagate(self, value):
    """Record one visit on this node and pass the result up to the root.

    Adds ``value`` to ``self.value_sum`` and bumps ``self.visit_count``,
    then hands the value returned by ``self.game.get_opponent_value(value)``
    to the parent's ``backpropagate``. The recursion ends at the node
    whose ``parent`` is ``None``.
    """
    self.value_sum += value
    self.visit_count += 1

    # The value is converted with the game's helper before it is handed to
    # the parent (presumably because the parent belongs to the other
    # player -- confirm against get_opponent_value).
    flipped = self.game.get_opponent_value(value)

    ancestor = self.parent
    if ancestor is None:
        return
    ancestor.backpropagate(flipped)

The function is being called here in the search():

@torch.no_grad()
def search(self, state):
    """Run MCTS from ``state`` and return the root's visit distribution.

    The root is expanded with the model's policy mixed with Dirichlet
    noise (``args['dirichlet_epsilon']`` / ``args['dirichlet_alpha']``).
    Then ``args['num_searches']`` iterations are run; each one descends
    with ``select()`` while nodes are fully expanded, evaluates the node
    it stops at, expands it if the game is not over there, and
    backpropagates the resulting value.

    Returns:
        A NumPy vector of length ``self.game.action_size`` whose entry for
        each root child's ``action_taken`` is that child's share of the
        visits. If no child was visited at all the final division is 0/0.
    """

    def evaluate(position):
        # One forward pass: softmaxed policy as a NumPy vector plus the
        # model's (still tensor-valued) value output.
        encoded = torch.tensor(
            self.game.get_encoded_state(position), device=self.model.device
        ).unsqueeze(0)
        logits, value = self.model(encoded)
        return torch.softmax(logits, dim=1).squeeze(0).cpu().numpy(), value

    def restrict_to_valid(policy, position):
        # Zero out invalid moves and renormalise so the priors sum to 1.
        valid_moves = np.asarray(self.game.get_valid_moves(position))
        masked = policy * valid_moves
        total = masked.sum()
        if total == 0 and valid_moves.any():
            # All probability mass sat on invalid moves; dividing would
            # produce NaN priors, so spread it evenly over the valid ones.
            mask = valid_moves != 0
            return mask.astype(masked.dtype) / mask.sum()
        return masked / total

    root = Node(self.game, self.args, state, visit_count=1)

    policy, _ = evaluate(state)
    epsilon = self.args['dirichlet_epsilon']
    noise = np.random.dirichlet(
        [self.args['dirichlet_alpha']] * self.game.action_size
    )
    policy = (1 - epsilon) * policy + epsilon * noise
    root.expand(restrict_to_valid(policy, state))

    for _ in range(self.args['num_searches']):
        # Selection: walk down until a node that is not fully expanded.
        node = root
        while node.is_fully_expanded():
            node = node.select()

        value, is_terminal = self.game.get_value_and_terminated(
            node.state, node.action_taken
        )
        # NOTE(review): presumably converts the result to the perspective
        # expected by node.backpropagate -- confirm against the game class.
        value = self.game.get_opponent_value(value)

        if not is_terminal:
            # Non-terminal position: the model's value replaces the game
            # result and its policy supplies the priors for the expansion.
            policy, value = evaluate(node.state)
            value = value.item()
            node.expand(restrict_to_valid(policy, node.state))

        node.backpropagate(value)

    action_probs = np.zeros(self.game.action_size)
    for child in root.children:
        action_probs[child.action_taken] = child.visit_count
    action_probs /= np.sum(action_probs)
    return action_probs

However, when the code runs, the visit counts of all but one child of the root node are always zero.

Can anyone help? This problem has been persisting for a long time and I can't seem to find a solution.

I tried checking the visit_count in the backpropagate function when the grandparent is None (i.e. the parent is the root node), and it printed some valid values for visit_count, but they always disappear afterwards:

def backpropagate(self, value):
    """Debugging variant of the backup step.

    Adds ``value`` to ``self.value_sum``, increments ``self.visit_count``
    and forwards the negated value to the parent's ``backpropagate``.
    When this node's parent has no parent of its own, the node's updated
    visit count is printed before recursing.
    """
    self.value_sum += value
    self.visit_count += 1

    negated = -value

    parent = self.parent
    if parent is None:
        return

    # Debug output: only fires for nodes directly below a parentless node.
    if parent.parent is None:
        print(self.visit_count)
    parent.backpropagate(negated)
0

There are 0 best solutions below