[Deep Q-Network] How to exclude ops from TensorFlow's automatic differentiation


I am trying to create a Deep Q-Network (DQN) similar to DeepMind's DQN3.0 using TensorFlow, but I am having some difficulties. I think the cause is TensorFlow's automatic differentiation approach.

Please see the figure below; it shows the architecture of DQN3.0.

[Figure: DQN architecture]

In supervised learning, in order to bring the output of the network closer to the label, you compute the difference with a loss function, backpropagate it, and update the parameters with the optimizer.
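
As a minimal sketch of what I mean (assuming TensorFlow 1.x; the tiny model here is only an illustration, not my actual network):

import tensorflow as tf

# Illustrative supervised setup: inputs, labels, and a single dense layer.
x = tf.placeholder(tf.float32, [None, 4])
y_label = tf.placeholder(tf.float32, [None, 2])
w = tf.Variable(tf.random_normal([4, 2]))
y_pred = tf.matmul(x, w)

# The loss measures the difference between prediction and label;
# the optimizer backpropagates it and updates the parameters.
loss = tf.reduce_mean(tf.square(y_pred - y_label))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)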

In DQN, states the agent experienced in the past, accumulated in replay memory, are fed again into two neural networks, TargetNetwork and Network, and the difference between the two networks' outputs is backpropagated into Network.

The outputs of each network are not probabilities summing to 1 but expected values (Q-values). TargetNetwork's output is scaled by the discount rate (gamma) and has the reward earned at that step added to it.
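
In other words, the target is computed outside the network. A rough NumPy sketch of that computation (the names are mine, following the comment in the code below):

import numpy as np

# target = r + (1 - terminal) * gamma * max_a Q_target(s2, a)
# q2_all: TargetNetwork outputs for s2, shape (batch, n_actions)
def q_target(r, terminal, q2_all, gamma=0.99):
    q2_max = q2_all.max(axis=1)  # max_a Q_target(s2, a)
    return r + (1.0 - terminal) * gamma * q2_max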

And, looking at the implementation of DQN3.0 (Lua + Torch), it compares the output of the current network for the action selected at that time against this target, and backpropagates the difference directly with the backward method:

function nql:getQUpdate(args)
    local s, a, r, s2, term, delta
    local q, q2, q2_max

    s = args.s
    a = args.a
    r = args.r
    s2 = args.s2
    term = args.term

    -- The order of calls to forward is a bit odd in order
    -- to avoid unnecessary calls (we only need 2).

    -- delta = r + (1-terminal) * gamma * max_a Q(s2, a) - Q(s, a)
    term = term:clone():float():mul(-1):add(1)

    local target_q_net
    if self.target_q then
        target_q_net = self.target_network
    else
        target_q_net = self.network
    end

    -- Compute max_a Q(s_2, a).
    q2_max = target_q_net:forward(s2):float():max(2)

    -- Compute q2 = (1-terminal) * gamma * max_a Q(s2, a)
    q2 = q2_max:clone():mul(self.discount):cmul(term)

    delta = r:clone():float()

    if self.rescale_r then
        delta:div(self.r_max)
    end
    delta:add(q2)

    -- q = Q(s,a)
    local q_all = self.network:forward(s):float()
    q = torch.FloatTensor(q_all:size(1))
    for i=1,q_all:size(1) do
        q[i] = q_all[i][a[i]]
    end
    delta:add(-1, q)

    if self.clip_delta then
        delta[delta:ge(self.clip_delta)] = self.clip_delta
        delta[delta:le(-self.clip_delta)] = -self.clip_delta
    end

    local targets = torch.zeros(self.minibatch_size, self.n_actions):float()
    for i=1,math.min(self.minibatch_size,a:size(1)) do
        targets[i][a[i]] = delta[i]
    end

    if self.gpu >= 0 then targets = targets:cuda() end

    return targets, delta, q2_max
end


function nql:qLearnMinibatch()
    -- Perform a minibatch Q-learning update:
    -- w += alpha * (r + gamma max Q(s2,a2) - Q(s,a)) * dQ(s,a)/dw
    assert(self.transitions:size() > self.minibatch_size)

    local s, a, r, s2, term = self.transitions:sample(self.minibatch_size)

    local targets, delta, q2_max = self:getQUpdate{s=s, a=a, r=r, s2=s2,
    term=term, update_qmax=true}

    -- zero gradients of parameters
    self.dw:zero()

    -- get new gradient
    self.network:backward(s, targets)

    -- (the rest of the update is omitted in this excerpt)
end

For this reason, I am thinking that if the speed of the compute block in the figure above does not matter, I could calculate it with NumPy etc. on the CPU instead of TensorFlow, and thereby exclude it from automatic differentiation.
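
For example, something like this sketch (TensorFlow 1.x; the tiny network and all names here are stand-ins for my real model), where the target is computed with NumPy and fed in through a placeholder so that it becomes a graph leaf:

import tensorflow as tf

# Hypothetical stand-ins for the real model.
n_actions = 4
states = tf.placeholder(tf.float32, [None, 84 * 84])
q_outputs = tf.layers.dense(states, n_actions)   # stand-in for the Network

actions = tf.placeholder(tf.int32, [None])       # actions taken
targets = tf.placeholder(tf.float32, [None])     # computed with NumPy, outside the graph

# Select Q(s, a) for the chosen actions.
indices = tf.stack([tf.range(tf.shape(q_outputs)[0]), actions], axis=1)
q_selected = tf.gather_nd(q_outputs, indices)

# `targets` is a placeholder, i.e. a graph leaf, so autodiff never enters
# the compute block that produced it.
loss = tf.reduce_mean(tf.square(targets - q_selected))
train_op = tf.train.RMSPropOptimizer(0.00025).minimize(loss)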

In DQN3.0, backpropagation starts only from the output layer of Network (in blue). With my TensorFlow model, however, it starts from the final mul op.

I want to start backpropagation from that same output layer in TensorFlow, as DQN3.0 does.

I understand that I can get grads_and_vars with the optimizer's compute_gradients() method and then run a manual differentiation process written from scratch. But implementing such a derivative for the convolution layers seems very difficult to me.
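
What I mean is roughly this (illustrative only; any loss over trainable variables works here):

import tensorflow as tf

w = tf.Variable(1.0)
loss = tf.square(w - 3.0)

optimizer = tf.train.RMSPropOptimizer(0.00025)
# compute_gradients returns (gradient, variable) pairs...
grads_and_vars = optimizer.compute_gradients(loss)
# ...which can be inspected or modified before being applied:
clipped = [(tf.clip_by_value(g, -1.0, 1.0), v)
           for g, v in grads_and_vars if g is not None]
train_op = optimizer.apply_gradients(clipped)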

Can I exclude the compute block's ops from automatic differentiation using TensorFlow functions or something similar? Or is there another way to solve this?

2 Answers

Accepted Answer

Thanks all.

I tentatively solved the exclusion problem.

I created a custom function by modifying tf.gradients as follows.

def gradients(ys,
              xs,
              grad_start,  # ***** Added argument: the op at which the gradient should start *****
              grad_ys=None,
              name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None):
    .
    .
    .
    # The set of 'from_ops'.
    stop_ops = _StopOps(from_ops, pending_count)
    while queue:
        # generate gradient subgraph for op.
        op = queue.popleft()
        with _maybe_colocate_with(op, colocate_gradients_with_ops):
            if loop_state:
                loop_state.EnterGradWhileContext(op, before=True)
            out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
            if loop_state:
                loop_state.ExitGradWhileContext(op, before=True)

            # *************************************************************
            # Added 2 lines: replace 'out_grads' with 'ys' at the new
            # gradient start op.
            if grad_start is not None and op == grad_start.op:
                out_grads = ys
            .
            .
            .

I confirmed with TensorBoard that the resulting graph looks as expected: https://i.stack.imgur.com/vG1e0.png
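
For reference, a similar effect might be achievable without patching TensorFlow at all: the stock tf.gradients already accepts a grad_ys argument that injects an initial gradient at ys, much like passing targets to backward in Torch. A sketch under that assumption (the tiny network and all names here are stand-ins):

import tensorflow as tf

n_actions = 4
states = tf.placeholder(tf.float32, [None, 84 * 84])
q_outputs = tf.layers.dense(states, n_actions)   # stand-in for the Network

# `targets_ph` plays the role of DQN3.0's `targets` (the delta scattered at
# the chosen actions). Note that apply_gradients *descends*, while the Torch
# code ascends along delta, so the injected gradient is negated here.
targets_ph = tf.placeholder(tf.float32, [None, n_actions])
params = tf.trainable_variables()
grads = tf.gradients(q_outputs, params, grad_ys=-targets_ph)
train_op = tf.train.RMSPropOptimizer(0.00025).apply_gradients(list(zip(grads, params)))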

Second Answer

What you are asking is not compatible with training your network.

The backpropagation starting point is your loss function (a.k.a. the red block). In this loss function, every operation that can be differentiated will be, so that the gradient can "flow" through the network. Every operation that you "exclude" from this process (using tf.stop_gradient, for example) becomes a stopping point for backpropagation, and the gradient will not flow to any of the values that operation depends on.
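
For example, a minimal sketch of that stopping behaviour:

import tensorflow as tf

a = tf.Variable(2.0)
b = tf.stop_gradient(a * 3.0)   # the gradient is blocked here
loss = b * a
# tf.gradients treats b as a constant: d(loss)/da == b (6.0), not 2 * 3 * a (12.0).
grads = tf.gradients(loss, [a])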

Basically, what does that mean? If you exclude the "compute block", you cannot compute any gradient for any of your variables.