From b8d21d5be8cf89fdb537c8ba32aa0320fbaabdff Mon Sep 17 00:00:00 2001 From: Size43 Date: Wed, 20 May 2015 11:25:57 +0200 Subject: Saving FeedSorter + train() method to train neural network on previously saved feedback --- .../main/java/org/rssin/neurons/FeedSorter.java | 133 ++++++++++++--------- .../java/org/rssin/neurons/FeedSorterStorage.java | 44 +++++++ app/src/main/java/org/rssin/neurons/Feedback.java | 16 ++- .../java/org/rssin/neurons/MultiNeuralNetwork.java | 1 + .../neurons/MultiNeuralNetworkPrediction.java | 10 ++ .../main/java/org/rssin/neurons/NeuralNetwork.java | 5 +- .../org/rssin/neurons/NeuralNetworkPrediction.java | 4 +- app/src/main/java/org/rssin/neurons/Neuron.java | 1 + .../org/rssin/neurons/PredictionInterface.java | 1 + .../main/java/org/rssin/neurons/TrainingCase.java | 23 ++++ 10 files changed, 174 insertions(+), 64 deletions(-) create mode 100755 app/src/main/java/org/rssin/neurons/FeedSorterStorage.java create mode 100755 app/src/main/java/org/rssin/neurons/TrainingCase.java (limited to 'app/src/main/java') diff --git a/app/src/main/java/org/rssin/neurons/FeedSorter.java b/app/src/main/java/org/rssin/neurons/FeedSorter.java index 550d9e5..157a8f7 100755 --- a/app/src/main/java/org/rssin/neurons/FeedSorter.java +++ b/app/src/main/java/org/rssin/neurons/FeedSorter.java @@ -5,7 +5,9 @@ import android.gesture.Prediction; import org.rssin.rss.FeedItem; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.Collection; import java.util.Collections; @@ -18,39 +20,43 @@ import java.util.TimeZone; /** * Created by Jos on 14-5-2015. */ -public class FeedSorter { +public class FeedSorter implements Serializable{ + private static final long serialVersionUID = 0; + + private final int MAX_TRAINING_HISTORY = 250; private final int SECONDS_IN_DAY = 24 * 60 * 60; private MultiNeuralNetwork nn = new MultiNeuralNetwork(25, 50); + private List trainingCases = new ArrayList<>(); + private int[] isNthMonthInput = new int[12]; private int[] isNthWeekDayInput = new int[7]; private int isMorning, isAfternoon, isEvening, isNight, biasInput; private Hashtable categoryInputs = new Hashtable(); - private Hashtable wordInputs = new Hashtable(); - private Hashtable feedSourceInputs = new Hashtable(); + //private Hashtable wordInputs = new Hashtable(); + //private Hashtable feedSourceInputs = new Hashtable(); public FeedSorter() { - //TODO: Load Neural Network createNewNetwork(); } private void createNewNetwork() { biasInput = nn.addInput(); -// for(int i = 0; i < 12; i++) -// { -// isNthMonthInput[i] = nn.addInput(); -// } -// -// for(int i = 0; i < 7; i++) -// { -// isNthWeekDayInput[i] = nn.addInput(); -// } -// -// isMorning = nn.addInput(); -// isAfternoon = nn.addInput(); -// isEvening = nn.addInput(); -// isNight = nn.addInput(); + for(int i = 0; i < 12; i++) + { + isNthMonthInput[i] = nn.addInput(); + } + + for(int i = 0; i < 7; i++) + { + isNthWeekDayInput[i] = nn.addInput(); + } + + isMorning = nn.addInput(); + isAfternoon = nn.addInput(); + isEvening = nn.addInput(); + isNight = nn.addInput(); } private PredictionInterface getPrediction(FeedItem item) { @@ -75,39 +81,39 @@ public class FeedSorter { } //Set month -// Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); -// for(int i = 0; i < isNthMonthInput.length; i++) -// { -// if(cal.get(Calendar.MONTH) - cal.getMinimum(Calendar.MONTH) == i) -// { -// inputs[isNthMonthInput[i]] = 1; -// } -// } + Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + for(int i = 0; i < isNthMonthInput.length; i++) + { + if(cal.get(Calendar.MONTH) - cal.getMinimum(Calendar.MONTH) == i) + { + inputs[isNthMonthInput[i]] = 1; + } + } //Set weekday -// for(int i = 0; i < isNthWeekDayInput.length; i++) -// { -// if(cal.get(Calendar.DAY_OF_WEEK) - cal.getMinimum(Calendar.DAY_OF_WEEK) == i) -// { -// inputs[isNthMonthInput[i]] = 1; -// } -// } + for(int i = 0; i < isNthWeekDayInput.length; i++) + { + if(cal.get(Calendar.DAY_OF_WEEK) - cal.getMinimum(Calendar.DAY_OF_WEEK) == i) + { + inputs[isNthMonthInput[i]] = 1; + } + } //Set day -// int hourOfDay = cal.get(Calendar.HOUR_OF_DAY); -// if(hourOfDay > 6 && hourOfDay < 12) -// { -// inputs[isMorning] = 1; -// }else if(hourOfDay >= 12 && hourOfDay <= 6) -// { -// inputs[isAfternoon] = 1; -// }else if(hourOfDay >= 6 && hourOfDay < 23) -// { -// inputs[isEvening] = 1; -// }else if(hourOfDay >= 23 || hourOfDay <= 6) -// { -// inputs[isNight] = 1; -// } + int hourOfDay = cal.get(Calendar.HOUR_OF_DAY); + if(hourOfDay > 6 && hourOfDay < 12) + { + inputs[isMorning] = 1; + }else if(hourOfDay >= 12 && hourOfDay <= 6) + { + inputs[isAfternoon] = 1; + }else if(hourOfDay >= 6 && hourOfDay < 23) + { + inputs[isEvening] = 1; + }else if(hourOfDay >= 23 || hourOfDay <= 6) + { + inputs[isNight] = 1; + } for(String category : item.getCategory()) { @@ -126,16 +132,31 @@ public class FeedSorter { public void feedback(FeedItem item, Feedback feedback) { PredictionInterface prediction = getPrediction(item); - switch(feedback) + prediction.learn(feedback.toExpectedOutput()); + trainingCases.add(new TrainingCase(prediction.getInputs(), feedback)); + + while(trainingCases.size() > MAX_TRAINING_HISTORY) + { + trainingCases.remove(0); + } + } + + /** + * Runs an iteration of training, using feedback that was provided previously using FeedSorter.feedback(...). + */ + public void train() + { + for(TrainingCase t : trainingCases) { - case Like: - prediction.learn(1); - break; - case Dislike: - prediction.learn(-1); - break; - default: - throw new IllegalArgumentException("feedback"); + double[] inputs = t.getInputs(); + if(inputs.length < nn.getInputCount()) + { + // Resize array to fit new input size + inputs = Arrays.copyOf(inputs, nn.getInputCount()); + } + + PredictionInterface prediction = nn.computeOutput(inputs); + prediction.learn(t.getFeedback().toExpectedOutput()); } } diff --git a/app/src/main/java/org/rssin/neurons/FeedSorterStorage.java b/app/src/main/java/org/rssin/neurons/FeedSorterStorage.java new file mode 100755 index 0000000..1479a3b --- /dev/null +++ b/app/src/main/java/org/rssin/neurons/FeedSorterStorage.java @@ -0,0 +1,44 @@ +package org.rssin.neurons; + +import android.content.Context; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +/** + * Created by Jos on 20-5-2015. + */ +public class FeedSorterStorage { + private final String file = "ml.dat"; + public FeedSorterStorage() + {} + + /** + * Reads a FeedSorter from an internal file. + * @param context The context + * @return The loaded FeedSorter + * @throws IOException + * @throws ClassNotFoundException + */ + public FeedSorter loadSorter(Context context) throws IOException, ClassNotFoundException { + ObjectInputStream is = new ObjectInputStream(context.openFileInput(file)); + FeedSorter read = (FeedSorter)is.readObject(); + is.close(); + return read; + } + + /** + * Saves the FeedSorter to an internal file. + * @param context The context + * @param s The FeedSorter to save + * @throws IOException + */ + public void saveSorter(Context context, FeedSorter s) throws IOException { + ObjectOutputStream os = new ObjectOutputStream(context.openFileOutput(file, Context.MODE_PRIVATE)); + os.writeObject(s); + os.close(); + } +} diff --git a/app/src/main/java/org/rssin/neurons/Feedback.java b/app/src/main/java/org/rssin/neurons/Feedback.java index af5f1b3..9652b2b 100755 --- a/app/src/main/java/org/rssin/neurons/Feedback.java +++ b/app/src/main/java/org/rssin/neurons/Feedback.java @@ -4,6 +4,18 @@ package org.rssin.neurons; * Created by Jos on 19-5-2015. */ public enum Feedback { - Like, - Dislike + Like, Dislike; + + double toExpectedOutput() + { + switch(this) + { + case Like: + return 1; + case Dislike: + return -1; + default: + throw new IllegalArgumentException(); + } + } } diff --git a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java index 50fa707..09dfc21 100755 --- a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java +++ b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java @@ -4,6 +4,7 @@ package org.rssin.neurons; * Created by Jos on 14-5-2015. */ class MultiNeuralNetwork { + private static final long serialVersionUID = 0; private NeuralNetwork[] networks; public MultiNeuralNetwork(int numNetworks, int numHiddenNodes) { diff --git a/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java b/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java index 8df99e5..dc85261 100755 --- a/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java +++ b/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java @@ -7,6 +7,11 @@ class MultiNeuralNetworkPrediction implements PredictionInterface { private PredictionInterface[] predictions; MultiNeuralNetworkPrediction(PredictionInterface[] predictions) { + if(predictions.length <= 0) + { + throw new IllegalArgumentException("predictions"); + } + this.predictions = predictions; } @@ -27,4 +32,9 @@ class MultiNeuralNetworkPrediction implements PredictionInterface { prediction.learn(expectedOutput); } } + + public double[] getInputs() + { + return predictions[0].getInputs(); + } } diff --git a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java index f5ee569..3766262 100755 --- a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java +++ b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java @@ -4,6 +4,7 @@ package org.rssin.neurons; * Created by Jos on 14-5-2015. */ class NeuralNetwork { + private static final long serialVersionUID = 0; private Neuron[] hiddenNodes; private Neuron outputNode; @@ -25,11 +26,7 @@ class NeuralNetwork { public int addInput() { int result = 0; for (int i = 0; i < hiddenNodes.length; i++) { -<<<<<<< HEAD result = hiddenNodes[i].addWeight(); -======= - //result = hiddenNodes[i].AddWeight(); ->>>>>>> 7336ef600f6bd472b9aa1d59ad6418ff5c543044 } return result; diff --git a/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java b/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java index 162ba21..9d6fc89 100755 --- a/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java +++ b/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java @@ -16,11 +16,11 @@ class NeuralNetworkPrediction implements PredictionInterface { this.nn = nn; } - double[] getInputs() { + public double[] getInputs() { return inputs; } - double[] getIntermediateValues() { + public double[] getIntermediateValues() { return intermediateValues; } diff --git a/app/src/main/java/org/rssin/neurons/Neuron.java b/app/src/main/java/org/rssin/neurons/Neuron.java index 5df9d47..724732d 100755 --- a/app/src/main/java/org/rssin/neurons/Neuron.java +++ b/app/src/main/java/org/rssin/neurons/Neuron.java @@ -8,6 +8,7 @@ import java.util.Random; * Created by Jos on 14-5-2015. */ class Neuron { + private static final long serialVersionUID = 0; private static Random r = new Random(); private List weights = new ArrayList(); diff --git a/app/src/main/java/org/rssin/neurons/PredictionInterface.java b/app/src/main/java/org/rssin/neurons/PredictionInterface.java index e130f4d..27a214f 100755 --- a/app/src/main/java/org/rssin/neurons/PredictionInterface.java +++ b/app/src/main/java/org/rssin/neurons/PredictionInterface.java @@ -6,4 +6,5 @@ package org.rssin.neurons; interface PredictionInterface { public double getOutput(); public void learn(double expectedOutput); + public double[] getInputs(); } diff --git a/app/src/main/java/org/rssin/neurons/TrainingCase.java b/app/src/main/java/org/rssin/neurons/TrainingCase.java new file mode 100755 index 0000000..e2680f8 --- /dev/null +++ b/app/src/main/java/org/rssin/neurons/TrainingCase.java @@ -0,0 +1,23 @@ +package org.rssin.neurons; + +/** + * Created by Jos on 20-5-2015. + */ +class TrainingCase { + private double[] inputs; + private Feedback feedback; + + public TrainingCase(double[] inputs, Feedback feedback) + { + this.inputs = inputs; + this.feedback = feedback; + } + + public double[] getInputs() { + return inputs; + } + + public Feedback getFeedback() { + return feedback; + } +} -- cgit v1.2.3