Software crash when forwarding a libtorch model


My software crashes when I try to forward my libtorch model, and I don't understand why. I compiled it with Visual Studio 2019 (v142), ISO C++20, and libtorch 2.0.0+cu118.

Here is the code:

struct CustomModel : torch::nn::Module {

    // Layer declarations
    torch::nn::LSTM lstm{ nullptr };
    torch::nn::Linear fc1{ nullptr };
    torch::nn::Linear output_layer{ nullptr };

    // Constructor
    CustomModel(int64_t input_channels, int64_t num_actions)
    {
        // Earlier member-initializer-list version, commented out while simplifying the model:
        //lstm(register_module("lstm", torch::nn::LSTM(torch::nn::LSTMOptions(60, 50).input_size(60).hidden_size(50)))),
        // LSTMOptions(60, 50) = input_size 60, hidden_size 50; num_layers(50) stacks 50 LSTM layers
        lstm = register_module("lstm",
            torch::nn::LSTM(torch::nn::LSTMOptions(60, 50).num_layers(50)));
        //fc1(register_module("fc1", torch::nn::Linear(11, 64))),  // 11 inputs other than the sequence
        //output_layer(register_module("output_layer", torch::nn::Linear(50 + 64, num_actions))) {}
        fc1 = register_module("fc1", torch::nn::Linear(13, 64));
        output_layer = register_module("output_layer", torch::nn::Linear(50 + 64, num_actions));
    }

    // Forward pass
    torch::Tensor forward(torch::Tensor sequence_input) {

        // TODO: modify the sequences and add a ReLU network + concatenate

        ::MessageBox(NULL, L"1", L"forward", MB_OK);

        torch::Tensor sequence_lstm = sequence_input.narrow(0, 0, 60); // the first 60 elements of sequence_input

        ::MessageBox(NULL, L"2", L"forward", MB_OK);

        torch::Tensor sequence_other = sequence_input.narrow(0, 60, sequence_input.size(0) - 60); // the remaining elements of sequence_input

        ::MessageBox(NULL, L"3", L"forward", MB_OK);

        wchar_t str1120[100];
        wsprintf(str1120, L"%d", sequence_other.sizes()[0]);
        ::MessageBox(NULL, str1120, L"sequence_other.sizes()[0]", MB_OK);

        wchar_t str1121[100];
        wsprintf(str1121, L"%d", sequence_other.sizes()[1]);
        ::MessageBox(NULL, str1121, L"sequence_other.sizes()[1]", MB_OK);

        wchar_t str1122[100];
        wsprintf(str1122, L"%d", sequence_other.sizes()[2]);
        ::MessageBox(NULL, str1122, L"sequence_other.sizes()[2]", MB_OK);

        wchar_t str1123[100];
        wsprintf(str1123, L"%d", sequence_other.sizes()[3]);
        ::MessageBox(NULL, str1123, L"sequence_other.sizes()[3]", MB_OK);

        // Forward pass for the sequence
        torch::Tensor lstm_output;

        //std::tie(lstm_output, std::ignore) = lstm(sequence_lstm.squeeze(2)); // Squeeze to remove the size-1 dimension

        ::MessageBox(NULL, L"4", L"forward", MB_OK);

        // Forward pass for the other inputs
        //torch::Tensor fc1_output = torch::relu(fc1(sequence_other.squeeze(2)));

        torch::Tensor fc1_output = torch::relu(fc1(sequence_other.squeeze(0)));

        ::MessageBox(NULL, L"5", L"forward", MB_OK);

        // Concatenate the outputs
        torch::Tensor combined_output = torch::cat({ lstm_output, fc1_output }, 1);

        ::MessageBox(NULL, L"6", L"forward", MB_OK);

        // Final forward pass
        torch::Tensor output = output_layer(combined_output);

        return output;
    }

    torch::Tensor act(torch::Tensor sequence_input) {
        torch::Tensor q_value = forward(sequence_input);
        torch::Tensor action = std::get<1>(q_value.max(0));
        return action;
    }
};
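
For reference, here is a standalone snippet showing how I understand the LSTM forward call I commented out is supposed to work (a minimal sketch with made-up shapes, not my actual code); the module seems to return a std::tuple of the output tensor and the (h_n, c_n) state pair, which is why the commented line used std::tie:

#include <torch/torch.h>
#include <iostream>

int main() {
    torch::nn::LSTM lstm(torch::nn::LSTMOptions(60, 50)); // input_size 60, hidden_size 50, 1 layer
    torch::Tensor x = torch::randn({ 10, 1, 60 });        // (seq_len, batch, input_size)
    auto result = lstm->forward(x);                       // std::tuple<Tensor, std::tuple<Tensor, Tensor>>
    torch::Tensor lstm_output = std::get<0>(result);      // output for every step, shape (10, 1, 50)
    std::cout << lstm_output.sizes() << std::endl;
    return 0;
}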

I call the model with:

torch::Tensor state_tensor = get_tensor_observation(state);

torch::Tensor action_tensor = network.act(state_tensor);

and get_tensor_observation is:

torch::Tensor Trainer::get_tensor_observation(std::vector<unsigned char> state) {
    std::vector<int64_t > state_int;
    state_int.reserve(state.size());

    for (int i = 0; i < state.size(); i++) {
        state_int.push_back(int64_t(state[i]));
    }

    // TODO: modify the from_blob call

    torch::Tensor state_tensor = torch::from_blob(state_int.data(), { 73 });
    return state_tensor;
}
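
The // TODO about from_blob refers to my understanding that from_blob neither copies the buffer nor infers its element type (it defaults to float32), while state_int holds int64_t and is destroyed when the function returns. A variant that states the dtype explicitly and copies the data would look roughly like this (an untested sketch; tensor_from_bytes is just a placeholder name):

#include <torch/torch.h>
#include <vector>

torch::Tensor tensor_from_bytes(const std::vector<unsigned char>& state) {
    std::vector<int64_t> state_int(state.begin(), state.end());
    torch::Tensor t = torch::from_blob(state_int.data(),
                                       { static_cast<int64_t>(state_int.size()) },
                                       torch::kInt64)
                          .clone()             // copy, so the tensor owns its memory
                          .to(torch::kFloat);  // Linear/LSTM layers expect floating-point input
    return t;
}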

I commented out some lines to simplify the model, but it still does not work. Does anyone have an idea how to solve this issue?

Thank you in advance,

Laurick
