High Memory Usage with Deeplearning4j during Model Inference


I'm experiencing high memory usage issues when performing inference with a neural network model using Deeplearning4j. Despite reducing the number of hidden nodes, the memory consumption remains unexpectedly high.

Here are the details of my setup:

  • Deeplearning4j version: deeplearning4j-core:1.0.0-M2
  • Backend: ND4J Backend CPU
  • Operating System: Mac
  • Java version: tried Java 15, Java 21

The model configuration is as follows:

  • Number of input nodes: 50
  • Number of hidden nodes: 420
  • Number of outputs: 7
  • Network depth: 10 layers
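
For scale, the architecture described above (ten dense torso layers of 420 units plus the two heads) has only about two million parameters, i.e. a few megabytes of float32 weights, so the 16 GB cannot be the weights themselves. A rough count, assuming dense layers with weights and biases only (no extra parameters such as batch norm):

```java
public class ParamCount {
    // Counts weights + biases for the dense architecture described above.
    static long count(int in, int hidden, int out, int depth) {
        long p = (long) in * hidden + hidden;                        // torso_0
        p += (long) (depth - 1) * ((long) hidden * hidden + hidden); // torso_1 .. torso_{depth-1}
        p += 2L * ((long) hidden * hidden + hidden);                 // policy_dense + value_dense
        p += (long) hidden * out + out;                              // policy_output
        p += hidden + 1;                                             // value_output (single unit)
        return p;
    }

    public static void main(String[] args) {
        long p = count(50, 420, 7, 10);
        System.out.println(p + " parameters, ~" + (p * 4 / (1024 * 1024)) + " MB as float32");
    }
}
```

That works out to roughly 2M parameters, so the memory observed at inference time must come from somewhere other than the model weights.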

The issue occurs during the call to the output method, where I observe memory usage of around 16.33 GB, which seems excessive for a model of this size. Here is the code snippet where the issue is observed:

Object[] rollout(State stato) {
        INDArray p;
        int v;
        if (trained) {
            INDArray inputData = stato.toINDArray();
            INDArray[] out = model.output(inputData);
            p = out[0];           // policy head
            v = out[1].getInt(0); // value head
        } else {
            // Happens when the models are not yet trained: pick a random value
            // and assign maximum probability to every child node.
            Random r = new Random();
            p = Nd4j.ones(Board.N).mul(Integer.MAX_VALUE);
            v = r.nextInt(MAXREWWARD); // avoids the Math.abs(nextInt()) overflow edge case
        }
        return new Object[]{p, v};
    }
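
One detail that may matter when interpreting the numbers: a JVM heap profiler only sees on-heap memory, while ND4J stores its INDArray buffers off-heap, so the reported gigabytes likely live outside the Java heap. A minimal sketch (pure JVM, no ND4J dependency) for comparing heap usage against what the OS reports for the process:

```java
public class HeapVsOffHeap {
    public static void main(String[] args) {
        Runtime rt = Runtime.getRuntime();
        long usedHeap = rt.totalMemory() - rt.freeMemory();
        System.out.println("JVM heap used (bytes): " + usedHeap);
        System.out.println("JVM heap max  (bytes): " + rt.maxMemory());
        // Anything the OS-level resident size reports beyond maxMemory() is
        // off-heap: ND4J workspaces, JavaCPP allocations, metaspace, threads.
    }
}
```

If the heap numbers here are small while Activity Monitor shows 16 GB, the memory is being held off-heap rather than leaked on the Java side.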

Here is my neural network configuration; I am trying to reproduce a version of AlphaZero:

DeepLearning(int numInputs, int numOutputs, String name, int M, int N, int X, int nndepth) throws IOException {
        this.numInputs = numInputs;
        this.numHiddenNodes = M*N*10;
        this.numOutputs = numOutputs;
        DeepLearning.M = M;
        DeepLearning.N = N;
        DeepLearning.X = X;
        this.name = name;
        File file = new File("./model" + name + M + "." + N + "." + X + "." + ".zip");
        if (file.exists()){
            trained = true;
            this.model = ComputationGraph.load(file , true);
        } else {
            ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder()
                    .seed(System.currentTimeMillis())
                    .weightInit(WeightInit.RELU)
                    .l2(1e-4)
                    .updater(new Adam(learningRate))
                    .graphBuilder();
            graphBuilder.addInputs("input")
                    .setInputTypes(InputType.feedForward(numInputs));
            String lastLayer = "input";
            for (int i = 0; i < nndepth; i++) {
                graphBuilder.addLayer("torso_" + i,
                        new DenseLayer.Builder()
                                .nIn(i == 0 ? numInputs : numHiddenNodes)
                                .nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .build(),
                        lastLayer);
                lastLayer = "torso_" + i;
            }

            graphBuilder.addLayer("policy_dense",
                    new DenseLayer.Builder()
                            .nIn(numHiddenNodes)
                            .nOut(numHiddenNodes)
                            .activation(Activation.RELU)
                            .build(),
                    lastLayer);
            graphBuilder.addLayer("policy_output",
                    new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .nIn(numHiddenNodes)
                            .nOut(numOutputs)
                            .activation(Activation.SOFTMAX)
                            .build(),
                    "policy_dense");

            graphBuilder.addLayer("value_dense",
                    new DenseLayer.Builder()
                            .nIn(numHiddenNodes)
                            .nOut(numHiddenNodes)
                            .activation(Activation.RELU)
                            .build(),
                    lastLayer);
            graphBuilder.addLayer("value_output",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                            .nIn(numHiddenNodes)
                            .nOut(1)
                            .activation(Activation.IDENTITY)
                            .build(),
                    "value_dense");

            graphBuilder.setOutputs("policy_output", "value_output");

            ComputationGraphConfiguration conf = graphBuilder.build();
            model = new ComputationGraph(conf);
        }
        model.init();
        System.out.println("Number of parameters: " + model.numParams());
        this.myReplay = new ReplayBuffer();
    }

    private String addDenseLayer(ComputationGraphConfiguration.GraphBuilder graphBuilder, String inputLayer, String layerName, int nIn, int nOut, Activation activation) {
        graphBuilder.addLayer(layerName, new DenseLayer.Builder()
                .nIn(nIn)
                .nOut(nOut)
                .activation(activation)
                .build(), inputLayer);
        return layerName;
    }

I have tried the following troubleshooting steps without success:

  • Ensuring the use of workspaces to manage memory efficiently.
  • Using in-place operations to reduce memory allocations.
  • Reducing the batch size during inference.
  • Checking for memory leaks and verifying that objects are garbage collected.

The profiler indicates that most of the computation time is spent in the output call of the ComputationGraph.
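
One thing worth checking that is not in the list above: with the ND4J CPU backend, off-heap allocation is capped by JavaCPP system properties rather than by -Xmx. A minimal sketch (the 2G/4G values are illustrative assumptions; the properties must take effect before any ND4J class is loaded, so in practice they are usually passed as -D flags on the command line):

```java
// Sketch: JavaCPP (which backs ND4J's CPU arrays) reads these system
// properties to cap off-heap allocation. Equivalent JVM flags:
//   -Dorg.bytedeco.javacpp.maxbytes=2G -Dorg.bytedeco.javacpp.maxphysicalbytes=4G
public class OffHeapLimits {
    public static void main(String[] args) {
        System.setProperty("org.bytedeco.javacpp.maxbytes", "2G");         // off-heap cap
        System.setProperty("org.bytedeco.javacpp.maxphysicalbytes", "4G"); // total process cap
        System.out.println(System.getProperty("org.bytedeco.javacpp.maxbytes"));
    }
}
```

If these are left unset, JavaCPP picks defaults relative to the heap size, which can allow the process to grow far beyond what the model itself needs.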

Here's the memory profile screenshot: memory usage
