Java Back-propagation ANN output values


I'm trying to write a simple implementation of a back-propagation ANN in Java and I'm getting very odd output. The ANN has an input layer with two nodes (one for each value in input vector), a single hidden layer with 4 nodes (I have experimented with more to no avail) and an output layer with 3 nodes. The three output nodes represent the three possible classifications of the data, with "one-hot" encoding.

Some input to the ANN is formatted as follows:

0.020055 0.40759 2
0.020117 0.14934 3
0.020117 0.25128 3
0.020262 1.6068 1

with the two decimal values as the two inputs, and the integer being the desired classification (1-3).
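
For reference, a minimal sketch of how a file in this format could be parsed (the file path, method name, and use of Scanner are illustrative, not taken from my actual code):

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Scanner;

public static void loadData(String path, List<double[]> inputs, List<Integer> labels) throws FileNotFoundException {
    try (Scanner sc = new Scanner(new File(path))) {
        sc.useLocale(Locale.US);        // make sure "." is read as the decimal separator
        while (sc.hasNextDouble()) {
            double x = sc.nextDouble(); // first input value
            double y = sc.nextDouble(); // second input value
            int label = sc.nextInt();   // desired classification (1-3)
            inputs.add(new double[]{x, y});
            labels.add(label);
        }
    }
}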

After training the ANN on every data point in the list, all of the outputs I get look like this, with the decimals being the outputs from the three output nodes and the trailing integer the desired classification:

0.11237534579646044   0.15172262242962917   0.7906017313009316   2
0.13686775670201043   0.1774606939461421   0.7656339988150088   3
0.13554918638846133   0.1761024282314506   0.766924262279491   3
0.06185317503169881   0.09410559150503017   0.8516964148498476   1

The way I have written it, each line should show one "high" value (close to 0.95, the target I train toward, since the sigmoid never quite reaches 1) and the rest should be "low" (close to 0.05).

Below is the method I have written to calculate the final output values:

public static double [] testANN(double [] input, List<BPNode> hiddenLayer, List <BPNode> outputLayer){
    double [] outInputs = new double[hiddenLayer.size()];
    double [] results = new double[outputLayer.size()];
    // forward pass: feed the raw input to every hidden node
    for(int i = 0; i<hiddenLayer.size(); i++){
        BPNode node = hiddenLayer.get(i);
        node.inputs = input.clone();
        outInputs[i] = node.getOutput();
    }
    // the hidden outputs become the inputs of every output node
    for(int i = 0; i<outputLayer.size(); i++){
        BPNode node = outputLayer.get(i);
        node.inputs = outInputs.clone();
        results[i] = node.getOutput();
    }

    return results;
}
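
Calling it for one sample from the data above looks like, e.g.:

double[] result = testANN(new double[]{0.020055, 0.40759}, hiddenLayer, outputLayer);
// result holds the three output-node activations; the largest should mark the predicted class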

And here is the method to train the ANN with the back-propagation algorithm:

public static void trainANN(double [] input, int desired, List<BPNode> hiddenLayer, List <BPNode> outputLayer){
    // target vector: 0.95 for the desired class, 0.05 for the other two
    double [] d = new double[outputLayer.size()];
    d[0] = desired==1 ? 0.95 : 0.05;
    d[1] = desired==2 ? 0.95 : 0.05;
    d[2] = desired==3 ? 0.95 : 0.05;
    // keep a copy of the pre-update output-layer weights for the hidden-layer deltas
    double [] [] weights = new double[outputLayer.size()][hiddenLayer.size()];
    double [] output = testANN(input,hiddenLayer,outputLayer);
    double [] del = new double[outputLayer.size()];
    // output-layer deltas and weight updates (learning rate 0.2)
    for(int i = 0; i<outputLayer.size(); i++){
        del[i] = (d[i]-output[i])*output[i]*(1-output[i]);
        for (int j = 0; j<outputLayer.get(i).weights.length; j++){
            weights[i][j] = outputLayer.get(i).weights[j];
            outputLayer.get(i).weights[j]+=0.2*del[i]*outputLayer.get(i).inputs[j];
        }
    }
    // back-propagate the deltas to the hidden layer
    for(int i = 0; i<hiddenLayer.size(); i++){
        double hiddenDel = 0.0;
        for(int j = 0; j<outputLayer.size(); j++){
            hiddenDel+=(del[j]*weights[j][i]*hiddenLayer.get(i).getOutput()*(1-hiddenLayer.get(i).getOutput()));
        }
        for(int j = 0; j<hiddenLayer.get(i).weights.length; j++){
            hiddenLayer.get(i).weights[j]+=0.2*hiddenDel*input[j];
        }
    }
}
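
The driver that calls this is essentially a loop over the data set, something like the sketch below (the epoch count is arbitrary; my exact loop may differ):

for (int epoch = 0; epoch < 1000; epoch++) {
    for (int k = 0; k < inputs.size(); k++) {
        trainANN(inputs.get(k), labels.get(k), hiddenLayer, outputLayer);
    }
}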

And finally, here's the Node class I have used to implement the ANN:

public class BPNode {

    public double [] inputs = new double[10];
    public double [] weights = new double[10];

    public BPNode(double [] w){
        weights = w;
    }

    // weighted sum of the stored inputs, squashed by a steep sigmoid
    public double getOutput() {
        double a = 0;
        for(int j = 0; j<inputs.length; j++){
            a += (inputs[j] * weights[j]);
        }
        return sigmoid(a,10.0);
    }

    // logistic function with slope factor m: 1 / (1 + e^(-m*x))
    private static double sigmoid(double x, double m)
    {
        return 1 / (1 + Math.exp(-x*m));
    }
}

All of the weights have been initialized to 0.1, and the nodes have been placed in the ArrayLists. Concretely, the setup amounts to something like the sketch below (sizes follow the 2-4-3 topology described above):
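
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

List<BPNode> hiddenLayer = new ArrayList<>();
for (int i = 0; i < 4; i++) {
    double[] w = new double[2];     // one weight per input value
    Arrays.fill(w, 0.1);
    hiddenLayer.add(new BPNode(w));
}

List<BPNode> outputLayer = new ArrayList<>();
for (int i = 0; i < 3; i++) {
    double[] w = new double[4];     // one weight per hidden node
    Arrays.fill(w, 0.1);
    outputLayer.add(new BPNode(w));
}

Thank you so much for your help.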

Best answer:

For anyone wondering how I solved this: I tried initializing the weights randomly between 0 and 1, which performed slightly better, but once I initialized them randomly between -1 and 1, the ANN was able to classify 85% of the data correctly.
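
Presumably this works because, with every weight starting at the same value, all nodes in a layer compute identical outputs and receive identical updates, so they can never learn to respond differently; random initialization breaks that symmetry. In code, the change is a small tweak to the weight setup, along the lines of:

import java.util.Random;

Random rng = new Random();
for (int i = 0; i < w.length; i++) {
    w[i] = 2 * rng.nextDouble() - 1;    // uniform in [-1, 1) instead of the constant 0.1
}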