aboutsummaryrefslogtreecommitdiff
path: root/app/src/main/java/org
diff options
context:
space:
mode:
authorSize432015-05-20 11:25:57 +0200
committerSize432015-05-20 11:25:57 +0200
commitb8d21d5be8cf89fdb537c8ba32aa0320fbaabdff (patch)
treeb03ccb7bb6e31b6caed366066f200647c32c2ee0 /app/src/main/java/org
parentNN JUnit tests + fix (diff)
Saving FeedSorter + train() method to train neural network on previously saved feedback
Diffstat (limited to 'app/src/main/java/org')
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/FeedSorter.java133
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/FeedSorterStorage.java44
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/Feedback.java16
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java1
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java10
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/NeuralNetwork.java5
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java4
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/Neuron.java1
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/PredictionInterface.java1
-rwxr-xr-xapp/src/main/java/org/rssin/neurons/TrainingCase.java23
10 files changed, 174 insertions, 64 deletions
diff --git a/app/src/main/java/org/rssin/neurons/FeedSorter.java b/app/src/main/java/org/rssin/neurons/FeedSorter.java
index 550d9e5..157a8f7 100755
--- a/app/src/main/java/org/rssin/neurons/FeedSorter.java
+++ b/app/src/main/java/org/rssin/neurons/FeedSorter.java
@@ -5,7 +5,9 @@ import android.gesture.Prediction;
import org.rssin.rss.FeedItem;
import java.io.IOException;
+import java.io.Serializable;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Calendar;
import java.util.Collection;
import java.util.Collections;
@@ -18,39 +20,43 @@ import java.util.TimeZone;
/**
* Created by Jos on 14-5-2015.
*/
-public class FeedSorter {
+public class FeedSorter implements Serializable{
+ private static final long serialVersionUID = 0;
+
+ private final int MAX_TRAINING_HISTORY = 250;
private final int SECONDS_IN_DAY = 24 * 60 * 60;
private MultiNeuralNetwork nn = new MultiNeuralNetwork(25, 50);
+ private List<TrainingCase> trainingCases = new ArrayList<>();
+
private int[] isNthMonthInput = new int[12];
private int[] isNthWeekDayInput = new int[7];
private int isMorning, isAfternoon, isEvening, isNight, biasInput;
private Hashtable<String, Integer> categoryInputs = new Hashtable<String, Integer>();
- private Hashtable<String, Integer> wordInputs = new Hashtable<String, Integer>();
- private Hashtable<String, Integer> feedSourceInputs = new Hashtable<String, Integer>();
+ //private Hashtable<String, Integer> wordInputs = new Hashtable<String, Integer>();
+ //private Hashtable<String, Integer> feedSourceInputs = new Hashtable<String, Integer>();
public FeedSorter() {
- //TODO: Load Neural Network
createNewNetwork();
}
private void createNewNetwork() {
biasInput = nn.addInput();
-// for(int i = 0; i < 12; i++)
-// {
-// isNthMonthInput[i] = nn.addInput();
-// }
-//
-// for(int i = 0; i < 7; i++)
-// {
-// isNthWeekDayInput[i] = nn.addInput();
-// }
-//
-// isMorning = nn.addInput();
-// isAfternoon = nn.addInput();
-// isEvening = nn.addInput();
-// isNight = nn.addInput();
+ for(int i = 0; i < 12; i++)
+ {
+ isNthMonthInput[i] = nn.addInput();
+ }
+
+ for(int i = 0; i < 7; i++)
+ {
+ isNthWeekDayInput[i] = nn.addInput();
+ }
+
+ isMorning = nn.addInput();
+ isAfternoon = nn.addInput();
+ isEvening = nn.addInput();
+ isNight = nn.addInput();
}
private PredictionInterface getPrediction(FeedItem item) {
@@ -75,39 +81,39 @@ public class FeedSorter {
}
//Set month
-// Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
-// for(int i = 0; i < isNthMonthInput.length; i++)
-// {
-// if(cal.get(Calendar.MONTH) - cal.getMinimum(Calendar.MONTH) == i)
-// {
-// inputs[isNthMonthInput[i]] = 1;
-// }
-// }
+ Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
+ for(int i = 0; i < isNthMonthInput.length; i++)
+ {
+ if(cal.get(Calendar.MONTH) - cal.getMinimum(Calendar.MONTH) == i)
+ {
+ inputs[isNthMonthInput[i]] = 1;
+ }
+ }
//Set weekday
-// for(int i = 0; i < isNthWeekDayInput.length; i++)
-// {
-// if(cal.get(Calendar.DAY_OF_WEEK) - cal.getMinimum(Calendar.DAY_OF_WEEK) == i)
-// {
-// inputs[isNthMonthInput[i]] = 1;
-// }
-// }
+ for(int i = 0; i < isNthWeekDayInput.length; i++)
+ {
+ if(cal.get(Calendar.DAY_OF_WEEK) - cal.getMinimum(Calendar.DAY_OF_WEEK) == i)
+ {
+ inputs[isNthMonthInput[i]] = 1;
+ }
+ }
//Set day
-// int hourOfDay = cal.get(Calendar.HOUR_OF_DAY);
-// if(hourOfDay > 6 && hourOfDay < 12)
-// {
-// inputs[isMorning] = 1;
-// }else if(hourOfDay >= 12 && hourOfDay <= 6)
-// {
-// inputs[isAfternoon] = 1;
-// }else if(hourOfDay >= 6 && hourOfDay < 23)
-// {
-// inputs[isEvening] = 1;
-// }else if(hourOfDay >= 23 || hourOfDay <= 6)
-// {
-// inputs[isNight] = 1;
-// }
+ int hourOfDay = cal.get(Calendar.HOUR_OF_DAY);
+ if(hourOfDay > 6 && hourOfDay < 12)
+ {
+ inputs[isMorning] = 1;
+ }else if(hourOfDay >= 12 && hourOfDay <= 6)
+ {
+ inputs[isAfternoon] = 1;
+ }else if(hourOfDay >= 6 && hourOfDay < 23)
+ {
+ inputs[isEvening] = 1;
+ }else if(hourOfDay >= 23 || hourOfDay <= 6)
+ {
+ inputs[isNight] = 1;
+ }
for(String category : item.getCategory())
{
@@ -126,16 +132,31 @@ public class FeedSorter {
public void feedback(FeedItem item, Feedback feedback)
{
PredictionInterface prediction = getPrediction(item);
- switch(feedback)
+ prediction.learn(feedback.toExpectedOutput());
+ trainingCases.add(new TrainingCase(prediction.getInputs(), feedback));
+
+ while(trainingCases.size() > MAX_TRAINING_HISTORY)
+ {
+ trainingCases.remove(0);
+ }
+ }
+
+ /**
+ * Runs an iteration of training, using feedback that was provided previously using FeedSorter.feedback(...).
+ */
+ public void train()
+ {
+ for(TrainingCase t : trainingCases)
{
- case Like:
- prediction.learn(1);
- break;
- case Dislike:
- prediction.learn(-1);
- break;
- default:
- throw new IllegalArgumentException("feedback");
+ double[] inputs = t.getInputs();
+ if(inputs.length < nn.getInputCount())
+ {
+ // Resize array to fit new input size
+ inputs = Arrays.copyOf(inputs, nn.getInputCount());
+ }
+
+ PredictionInterface prediction = nn.computeOutput(inputs);
+ prediction.learn(t.getFeedback().toExpectedOutput());
}
}
diff --git a/app/src/main/java/org/rssin/neurons/FeedSorterStorage.java b/app/src/main/java/org/rssin/neurons/FeedSorterStorage.java
new file mode 100755
index 0000000..1479a3b
--- /dev/null
+++ b/app/src/main/java/org/rssin/neurons/FeedSorterStorage.java
@@ -0,0 +1,44 @@
+package org.rssin.neurons;
+
+import android.content.Context;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+
+/**
+ * Created by Jos on 20-5-2015.
+ */
+public class FeedSorterStorage {
+ private final String file = "ml.dat";
+ public FeedSorterStorage()
+ {}
+
+ /**
+ * Reads a FeedSorter from an internal file.
+ * @param context The context
+ * @return The loaded FeedSorter
+ * @throws IOException
+ * @throws ClassNotFoundException
+ */
+ public FeedSorter loadSorter(Context context) throws IOException, ClassNotFoundException {
+ ObjectInputStream is = new ObjectInputStream(context.openFileInput(file));
+ FeedSorter read = (FeedSorter)is.readObject();
+ is.close();
+ return read;
+ }
+
+ /**
+ * Saves the FeedSorter to an internal file.
+ * @param context The context
+ * @param s The FeedSorter to save
+ * @throws IOException
+ */
+ public void saveSorter(Context context, FeedSorter s) throws IOException {
+ ObjectOutputStream os = new ObjectOutputStream(context.openFileOutput(file, Context.MODE_PRIVATE));
+ os.writeObject(s);
+ os.close();
+ }
+}
diff --git a/app/src/main/java/org/rssin/neurons/Feedback.java b/app/src/main/java/org/rssin/neurons/Feedback.java
index af5f1b3..9652b2b 100755
--- a/app/src/main/java/org/rssin/neurons/Feedback.java
+++ b/app/src/main/java/org/rssin/neurons/Feedback.java
@@ -4,6 +4,18 @@ package org.rssin.neurons;
* Created by Jos on 19-5-2015.
*/
public enum Feedback {
- Like,
- Dislike
+ Like, Dislike;
+
+ double toExpectedOutput()
+ {
+ switch(this)
+ {
+ case Like:
+ return 1;
+ case Dislike:
+ return -1;
+ default:
+ throw new IllegalArgumentException();
+ }
+ }
}
diff --git a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java
index 50fa707..09dfc21 100755
--- a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java
+++ b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java
@@ -4,6 +4,7 @@ package org.rssin.neurons;
* Created by Jos on 14-5-2015.
*/
class MultiNeuralNetwork {
+ private static final long serialVersionUID = 0;
private NeuralNetwork[] networks;
public MultiNeuralNetwork(int numNetworks, int numHiddenNodes) {
diff --git a/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java b/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java
index 8df99e5..dc85261 100755
--- a/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java
+++ b/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java
@@ -7,6 +7,11 @@ class MultiNeuralNetworkPrediction implements PredictionInterface {
private PredictionInterface[] predictions;
MultiNeuralNetworkPrediction(PredictionInterface[] predictions)
{
+ if(predictions.length <= 0)
+ {
+ throw new IllegalArgumentException("predictions");
+ }
+
this.predictions = predictions;
}
@@ -27,4 +32,9 @@ class MultiNeuralNetworkPrediction implements PredictionInterface {
prediction.learn(expectedOutput);
}
}
+
+ public double[] getInputs()
+ {
+ return predictions[0].getInputs();
+ }
}
diff --git a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java
index f5ee569..3766262 100755
--- a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java
+++ b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java
@@ -4,6 +4,7 @@ package org.rssin.neurons;
* Created by Jos on 14-5-2015.
*/
class NeuralNetwork {
+ private static final long serialVersionUID = 0;
private Neuron[] hiddenNodes;
private Neuron outputNode;
@@ -25,11 +26,7 @@ class NeuralNetwork {
public int addInput() {
int result = 0;
for (int i = 0; i < hiddenNodes.length; i++) {
-<<<<<<< HEAD
result = hiddenNodes[i].addWeight();
-=======
- //result = hiddenNodes[i].AddWeight();
->>>>>>> 7336ef600f6bd472b9aa1d59ad6418ff5c543044
}
return result;
diff --git a/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java b/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java
index 162ba21..9d6fc89 100755
--- a/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java
+++ b/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java
@@ -16,11 +16,11 @@ class NeuralNetworkPrediction implements PredictionInterface {
this.nn = nn;
}
- double[] getInputs() {
+ public double[] getInputs() {
return inputs;
}
- double[] getIntermediateValues() {
+ public double[] getIntermediateValues() {
return intermediateValues;
}
diff --git a/app/src/main/java/org/rssin/neurons/Neuron.java b/app/src/main/java/org/rssin/neurons/Neuron.java
index 5df9d47..724732d 100755
--- a/app/src/main/java/org/rssin/neurons/Neuron.java
+++ b/app/src/main/java/org/rssin/neurons/Neuron.java
@@ -8,6 +8,7 @@ import java.util.Random;
* Created by Jos on 14-5-2015.
*/
class Neuron {
+ private static final long serialVersionUID = 0;
private static Random r = new Random();
private List<Double> weights = new ArrayList<Double>();
diff --git a/app/src/main/java/org/rssin/neurons/PredictionInterface.java b/app/src/main/java/org/rssin/neurons/PredictionInterface.java
index e130f4d..27a214f 100755
--- a/app/src/main/java/org/rssin/neurons/PredictionInterface.java
+++ b/app/src/main/java/org/rssin/neurons/PredictionInterface.java
@@ -6,4 +6,5 @@ package org.rssin.neurons;
interface PredictionInterface {
public double getOutput();
public void learn(double expectedOutput);
+ public double[] getInputs();
}
diff --git a/app/src/main/java/org/rssin/neurons/TrainingCase.java b/app/src/main/java/org/rssin/neurons/TrainingCase.java
new file mode 100755
index 0000000..e2680f8
--- /dev/null
+++ b/app/src/main/java/org/rssin/neurons/TrainingCase.java
@@ -0,0 +1,23 @@
+package org.rssin.neurons;
+
+/**
+ * Created by Jos on 20-5-2015.
+ */
+class TrainingCase {
+ private double[] inputs;
+ private Feedback feedback;
+
+ public TrainingCase(double[] inputs, Feedback feedback)
+ {
+ this.inputs = inputs;
+ this.feedback = feedback;
+ }
+
+ public double[] getInputs() {
+ return inputs;
+ }
+
+ public Feedback getFeedback() {
+ return feedback;
+ }
+}