SeLU Activation Function Implementation In GRUCell PyTorch C++


The actual task is to replace the tanh_() at line #799 with the SeLU activation function in the new_gate of gru_cell. The following code block is from the RNN.cpp file in the PyTorch GitHub repo.

template <typename cell_params>
struct GRUCell : Cell<Tensor, cell_params> {
  using hidden_type = Tensor;

  hidden_type operator()(
      const Tensor& input,
      const hidden_type& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const override {
    if (input.is_cuda() || input.is_xpu()) {
      TORCH_CHECK(!pre_compute_input);
      auto igates = params.matmul_ih(input);
      auto hgates = params.matmul_hh(hidden);
      auto result = at::_thnn_fused_gru_cell(
          igates, hgates, hidden, params.b_ih(), params.b_hh());
      // Slice off the workspace argument (it's needed only for AD).
      return std::move(std::get<0>(result));
    }
    const auto chunked_igates = pre_compute_input
        ? input.unsafe_chunk(3, 1)
        : params.linear_ih(input).unsafe_chunk(3, 1);
    auto chunked_hgates = params.linear_hh(hidden).unsafe_chunk(3, 1);
    const auto reset_gate =
        chunked_hgates[0].add_(chunked_igates[0]).sigmoid_();
    const auto input_gate =
        chunked_hgates[1].add_(chunked_igates[1]).sigmoid_();
    const auto new_gate =
        chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate)).tanh_();
    return (hidden - new_gate).mul_(input_gate).add_(new_gate);
  }
};

new_gate is a Tensor. How can we implement a custom function that iterates over the Tensor and applies the SeLU activation function to its elements?

  • I replaced the tanh_() with the selu_() that is present in the build/aten/src/ATen/ops/selu.h header (after building PyTorch from source in develop mode) and also included the related header files. But on re-building it generated the error "Did you mean relu_()?".

  • I also tried to implement my own selu() function, but I ran into problems with the Tensor datatype.

There is 1 answer below.

Kozydot

In ATen, selu() and selu_() are not exposed as Tensor methods the way tanh_() and relu_() are; they exist only as free functions (at::selu and at::selu_). That is why replacing .tanh_() with .selu_() fails to compile and the compiler suggests "Did you mean relu_()?".

If you want to apply the SeLU activation function to new_gate, compute the gate first and then pass it through at::selu(), assigning the result back to new_gate:

auto new_gate = chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate));
new_gate = at::selu(new_gate);

new_gate is first computed, then the SeLU activation is applied with at::selu(), and the result is assigned back to new_gate.

Note that this functional form allocates a new tensor, whereas in-place operations like tanh_() avoid that extra allocation. If you want to keep the in-place style, you can apply the free function at::selu_() to the already materialized new_gate instead, as sketched below.
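For example, keeping the structure of the original cell, the in-place variant could look roughly like this (a sketch only; which ATen header declares at::selu_ depends on the PyTorch version and whether per-operator headers are used):

// Inside GRUCell::operator() in RNN.cpp, replacing the tanh_() line.
// at::selu_ takes a non-const lvalue reference, so the gate is materialized first.
auto new_gate = chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate));
at::selu_(new_gate);  // in-place: rewrites new_gate's storage, no extra allocation
return (hidden - new_gate).mul_(input_gate).add_(new_gate);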

Regarding your attempt to implement a custom SeLU function: it needs to accept and return PyTorch's tensor type (at::Tensor) and should be built from existing tensor operations, so that it works on whole tensors rather than iterating element by element. This is generally not recommended unless you have a specific reason, since the built-in at::selu is already optimized and should usually be preferred, but a sketch of such a helper is shown below.
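For illustration, here is one way such a helper might be written with ATen tensor ops, using the usual SELU constants; the name my_selu is made up for this sketch, and the built-in at::selu remains the better choice:

#include <ATen/ATen.h>

// Hypothetical helper: SELU expressed with ATen tensor ops, so it applies to a
// whole at::Tensor without any manual element-wise loop.
// selu(x) = scale * x                     for x > 0
//           scale * alpha * (exp(x) - 1)  otherwise
static at::Tensor my_selu(const at::Tensor& x) {
  constexpr double alpha = 1.6732632423543772;  // canonical SELU alpha
  constexpr double scale = 1.0507009873554805;  // canonical SELU scale
  return scale * at::where(x > 0, x, alpha * (x.exp() - 1));
}

In the GRU cell this would be called as my_selu(chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate))) in place of the original .tanh_() expression.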