Proper use of random seeds in JAX

I am using JAX for reinforcement learning with DQN. For a step in the environment, I am trying two alternative ways of generating the random key (seed). The two approaches lead to significantly different results. Why does this happen, and which approach aligns with the proper use of random keys in JAX?

The first one is the pattern shown in the JAX documentation:

rng, step_rng = jax.random.split(rng)
next_state, next_env_state, reward, terminated, info = env.step(step_rng, env_state, action.squeeze(), env_params)

The second one follows an example from purejaxrl:

rng, _rng = jax.random.split(rng)
_rng, step_rng = jax.random.split(_rng)
next_state, next_env_state, reward, terminated, info = env.step(step_rng, env_state, action.squeeze(), env_params)
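
For reference, a minimal standalone check (independent of the environment above) shows that the two schemes derive different keys from the same starting key, which I assume is why the two runs diverge:

# Both schemes are legal uses of JAX's PRNG; they just produce different step keys.
import jax

rng = jax.random.PRNGKey(0)

# Scheme 1: one split, use the second key directly.
_, step_rng_1 = jax.random.split(rng)

# Scheme 2: split twice, use a key derived from the intermediate key.
_, _rng = jax.random.split(rng)
_, step_rng_2 = jax.random.split(_rng)

print(jax.random.uniform(step_rng_1))  # the draws differ because the keys differ
print(jax.random.uniform(step_rng_2))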

EDIT: This is a short piece of code from my environment step function, which performs a step in a subenvironment.

state, env_state, _, _, rng = runner
q = self.agent_nn.apply(self.agent_params, state)
action = jnp.argmax(q)
rng, rng_step = jax.random.split(rng)
next_state, next_env_state, reward, terminated, info = self.env.step(
    rng_step, env_state, action, self.env_params)
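
For completeness, this is roughly how a key would end up inside runner before the step above runs; the two placeholder slots and the literal seed are illustrative assumptions, not my exact code:

# Hypothetical caller-side setup: split once per consumer and put a fresh key
# into the runner tuple (0.0 / False stand in for the values unpacked as "_" above).
rng = jax.random.PRNGKey(0)
rng, reset_rng = jax.random.split(rng)
state, env_state = self.env.reset(reset_rng, self.env_params)
rng, runner_rng = jax.random.split(rng)
runner = (state, env_state, 0.0, False, runner_rng)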

SECOND EDIT: use of the random key within lax.scan

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax_tqdm import scan_tqdm

# config, env, env_params, tx (the optimizer), q_network, buffer_state and the
# custom TrainState (which also stores target_params) are defined elsewhere and
# elided here.

def train(rng):

    rng, network_init_rng = jax.random.split(rng)
    network = q_network(env.action_space(env_params).n)
    init_x = jnp.zeros((1, config["STATE_SIZE"]))
    network_params = network.init(network_init_rng, init_x)
    
    training = TrainState.create(apply_fn=network.apply,
                                 params=network_params,
                                 target_params=network_params,
                                 tx=tx)

    rng, _rng = jax.random.split(rng)
    _rng, reset_rng = jax.random.split(_rng)
    state, env_state = env.reset(reset_rng, env_params)

    @jit
    @scan_tqdm(config["TOTAL_STEPS"])
    def _run_step(runner, i_step):

        training, env_state, state, rng, buffer_state, i_episode = runner

        rng, *_rng = jax.random.split(rng, 3)  # keep rng for the carry, take two fresh keys
        random_q_rng, random_number_rng = _rng

        q_state = network.apply(training.params, state)
        random_number = jax.random.uniform(random_number_rng, minval=0, maxval=1, shape=(1,))
        exploitation = jnp.greater(random_number, config["EPS"])
        # random_action (presumably drawn with random_q_rng) is computed elsewhere and elided here
        action = jnp.where(exploitation, jnp.argmax(q_state, 1), random_action)

        rng, _rng = jax.random.split(rng)
        _rng, step_rng = jax.random.split(_rng)
        next_state, next_env_state, reward, terminated, info = env.step(step_rng, env_state, action.squeeze(), env_params)
        
        # (updating the runner tuple with the new rng/state and building the
        # per-step metrics returned to lax.scan is elided here)
        return runner

    rng, _rng = jax.random.split(rng)
    runner = (training, env_state, state, _rng, buffer_state, 0)
    runner, metrics = lax.scan(_run_step, runner, jnp.arange(config["TOTAL_STEPS"]), config["TOTAL_STEPS"])

    return {"runner": runner}