Seq2Seq Model (DL4J) Making Absurd Predictions


I am trying to implement a Seq2Seq Predictor Model in DL4J. What I ultimately want is to use a time series of INPUT_SIZE data points to predict the following time series of OUTPUT_SIZE data points using this type of model. Each data point has numFeatures features. Now, DL4J has some example code explaining how to implement a very basic Seq2Seq model. I have made some progress extending their example to my own needs; the below model compiles, but the predictions it is making are nonsensical.

ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
    .weightInit(WeightInit.XAVIER)
    .updater(new Adam(0.25))
    .seed(42)
    .graphBuilder()
    .addInputs("in_data", "last_in")
    .setInputTypes(InputType.recurrent(numFeatures), InputType.recurrent(numFeatures))
    //The inputs to the encoder will have size = minibatch x featuresize x timesteps
    //Note that the network only knows the feature vector size; it does not know the number of time steps until it sees an instance of the data
    .addLayer("encoder", new LSTM.Builder().nIn(numFeatures).nOut(hiddenLayerWidth).activation(Activation.LEAKYRELU).build(), "in_data")
    //Create a vertex indicating that the very last time step of the encoder layer should be directed to other places in the comp graph
    .addVertex("lastTimeStep", new LastTimeStepVertex("in_data"), "encoder")
    //Create a vertex that allows the duplication of 2d input to a 3d input
    //In this case the last time step of the encoder layer (viz. 2d) is duplicated to the length of the time series "last_in", which is the second input to the comp graph
    //Refer to the javadoc for more detail
    .addVertex("duplicateTimeStep", new DuplicateToTimeSeriesVertex("last_in"), "lastTimeStep")
    //The inputs to the decoder will have size = size of the last time step of the encoder output (hiddenLayerWidth) + the feature vector size of the other comp graph input, "last_in"
    .addLayer("decoder", new LSTM.Builder().nIn(numFeatures + hiddenLayerWidth).nOut(hiddenLayerWidth).activation(Activation.LEAKYRELU).build(), "last_in", "duplicateTimeStep")
    .addLayer("output", new RnnOutputLayer.Builder().nIn(hiddenLayerWidth).nOut(numFeatures).activation(Activation.LEAKYRELU).lossFunction(LossFunctions.LossFunction.MSE).build(), "decoder")
    .setOutputs("output")
    .build();

ComputationGraph net = new ComputationGraph(configuration);
net.init();
net.setListeners(new ScoreIterationListener(1));

The way I structure my input/labeled data is that the input is split between the first INPUT_SIZE - 1 time series observations (corresponding to the in_data input of the ComputationGraph) and the last time series observation (corresponding to the last_in input). The labels are a single time step in the future; to make predictions, I simply call net.output() OUTPUT_SIZE times to get all the predictions I want. To see this more concretely, this is how I initialize my input/labels:

INDArray[] input = new INDArray[] {Nd4j.zeros(batchSize, numFeatures, INPUT_SIZE - 1), Nd4j.zeros(batchSize, numFeatures, 1)};
INDArray[] labels = new INDArray[] {Nd4j.zeros(batchSize, numFeatures, 1)};
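
For reference, my prediction loop looks roughly like the following. This is only a simplified sketch: the variable names (encoderWindow, lastObservation, predictions) are placeholders for illustration, and the way the window is shifted forward each step is an approximation of what I do.

//Rolling one-step-ahead prediction: feed the current window, take the prediction,
//then shift the window forward and use the prediction as the next "last_in" input
INDArray encoderWindow = input[0];   //shape [batchSize, numFeatures, INPUT_SIZE - 1]
INDArray lastObservation = input[1]; //shape [batchSize, numFeatures, 1]
INDArray[] predictions = new INDArray[OUTPUT_SIZE];

for (int step = 0; step < OUTPUT_SIZE; step++) {
    //net.output(...) returns one INDArray per graph output; this graph has only "output",
    //and its time dimension is 1 because "last_in" has a single time step
    INDArray predicted = net.output(encoderWindow, lastObservation)[0];
    predictions[step] = predicted;

    //Drop the oldest time step, append the previous "last" observation,
    //and use the new prediction as the next "last_in"
    encoderWindow = Nd4j.concat(2,
            encoderWindow.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, INPUT_SIZE - 1)),
            lastObservation);
    lastObservation = predicted;
}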

I believe the error comes from the architecture of my ComputationGraph and not from how I prepare my data, make predictions, or anything else, as I have done other mini-projects with simpler architectures and have had no problems.

My data is normalized to have mean 0 and a standard deviation of 1, so most entries should be around 0. However, most of the predictions I get have an absolute value much greater than zero (on the order of 10s to 100s), which is clearly not correct. I have been working on this for some time and have been unable to find the issue; any suggestions for how to fix this would be much appreciated.
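
(For reference, the normalization itself is just a z-score transform, roughly like the snippet below; this is a simplified sketch that uses a single global mean/std rather than per-feature statistics, and rawFeatures is a placeholder name for the unnormalized data.)

double mean = rawFeatures.meanNumber().doubleValue();
double std = rawFeatures.stdNumber().doubleValue();
//Subtract the mean and divide by the standard deviation so entries have mean 0 and std 1
INDArray normalized = rawFeatures.sub(mean).div(std);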

Other resources I have used: the example Seq2Seq model can be found here, starting at line 88, and the ComputationGraph documentation can be found here. I have read the latter extensively to see if I could find an error, to no avail.
