diff options
7 files changed, 189 insertions, 61 deletions
diff --git a/app/src/androidTest/java/org/rssin/neurons/FeedSorterTest.java b/app/src/androidTest/java/org/rssin/neurons/FeedSorterTest.java index a2d59ee..a6e3581 100755 --- a/app/src/androidTest/java/org/rssin/neurons/FeedSorterTest.java +++ b/app/src/androidTest/java/org/rssin/neurons/FeedSorterTest.java @@ -20,7 +20,7 @@ public class FeedSorterTest extends TestCase { Assert.assertTrue(true);
}
- public void testSortItems() throws Exception {
+ public void testSortItemsByCategory() throws Exception {
List<String> gameList = new ArrayList<>();
gameList.add("Games");
List<String> sportList = new ArrayList<>();
@@ -39,28 +39,90 @@ public class FeedSorterTest extends TestCase { FeedSorter s = new FeedSorter();
//I like games & I hate sports
- for(int i = 0; i < 100; i++)
- {
- for(FeedItem item : likedItems)
- {
- s.feedback(item, Feedback.Like);
- }
-
- for(FeedItem item : dislikedItems)
- {
- s.feedback(item, Feedback.Dislike);
- }
- }
+ trainNetwork(likedItems, dislikedItems, s);
FeedItem sportsItem = new FeedItem("", new Date(2014, 1, 1, 1, 2, 1), "SPORT ARTICLE", "DESCRIPTION", "", "Randy", sportList, "", "", "");
FeedItem gamesItem = new FeedItem("", new Date(2014, 1, 1, 1, 1, 1), "GAME ARTICLE", "DESCRIPTION", "", "Randy", gameList, "", "", "");
+ testSortingOrder(s, sportsItem, gamesItem);
+ }
+
+ public void testSortItemsByTitle() throws Exception {
+ List<String> emptyList = new ArrayList<>();
+
+ FeedItem[] likedItems = new FeedItem[]
+ {
+ new FeedItem("", new Date(), "Video games are cool", "DESCRIPTION", "", "Randy", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "The new video game", "DESCRIPTION", "", "Camil", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "Best games of 2015", "DESCRIPTION", "", "Jos", emptyList, "", "", ""),
+ };
+
+ FeedItem[] dislikedItems = new FeedItem[]
+ {
+ new FeedItem("", new Date(), "Video of a cat", "DESCRIPTION", "", "Randy", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "It's raining", "DESCRIPTION", "", "Joep", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "Shocking video of a cat in the rain.", "DESCRIPTION", "", "Joep", emptyList, "", "", ""),
+ };
+
+ FeedSorter s = new FeedSorter();
+
+ //I like games & I hate sports
+ trainNetwork(likedItems, dislikedItems, s);
+
+ FeedItem dislikedItem = new FeedItem("", new Date(2014, 1, 1, 1, 2, 1), "Another cool video of a cat in the sun.", "DESCRIPTION", "", "Randy", emptyList, "", "", "");
+ FeedItem likedItem = new FeedItem("", new Date(2014, 1, 1, 1, 1, 1), "Coolest retro games", "DESCRIPTION", "", "Jos", emptyList, "", "", "");
+
+ testSortingOrder(s, dislikedItem, likedItem);
+ }
+
+ public void testSortItemsByAuthor() throws Exception {
+ List<String> emptyList = new ArrayList<>();
+
+ FeedItem[] likedItems = new FeedItem[]
+ {
+ new FeedItem("", new Date(), "Best games of 2015", "DESCRIPTION", "", "Jos", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "It's raining cats and dogs!", "DESCRIPTION", "", "Jos", emptyList, "", "", ""),
+ };
+
+ FeedItem[] dislikedItems = new FeedItem[]
+ {
+ new FeedItem("", new Date(), "Video of a cat", "DESCRIPTION", "", "Randy", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "It's raining", "DESCRIPTION", "", "Joep", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "Shocking video of a cat in the rain.", "DESCRIPTION", "", "Joep", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "Video games are cool", "DESCRIPTION", "", "Randy", emptyList, "", "", ""),
+ new FeedItem("", new Date(), "The new video game", "DESCRIPTION", "", "Camil", emptyList, "", "", ""),
+ };
+
+ FeedSorter s = new FeedSorter();
+
+ //I like games & I hate sports
+ trainNetwork(likedItems, dislikedItems, s);
+
+ FeedItem dislikedItem = new FeedItem("", new Date(2014, 1, 1, 1, 2, 1), "Another cool video of a cat in the sun.", "DESCRIPTION", "", "Randy", emptyList, "", "", "");
+ FeedItem likedItem = new FeedItem("", new Date(2014, 1, 1, 1, 1, 1), "Coolest retro games", "DESCRIPTION", "", "Jos", emptyList, "", "", "");
+
+ testSortingOrder(s, dislikedItem, likedItem);
+ }
+
+ private void testSortingOrder(FeedSorter s, FeedItem dislikedItem, FeedItem likedItem) {
List<FeedItem> testItems = new LinkedList<>();
- testItems.add(sportsItem);
- testItems.add(gamesItem);
+ testItems.add(dislikedItem);
+ testItems.add(likedItem);
List<FeedItem> sortedItems = s.sortItems(testItems);
- Assert.assertEquals(sortedItems.get(0), gamesItem);
- Assert.assertEquals(sortedItems.get(1), sportsItem);
+ Assert.assertEquals(sortedItems.get(0), likedItem);
+ Assert.assertEquals(sortedItems.get(1), dislikedItem);
+ }
+
+ private void trainNetwork(FeedItem[] likedItems, FeedItem[] dislikedItems, FeedSorter s) {
+ for(int i = 0; i < 200; i++) {
+ for (FeedItem item : likedItems) {
+ s.feedback(item, Feedback.Like);
+ }
+
+ for (FeedItem item : dislikedItems) {
+ s.feedback(item, Feedback.Dislike);
+ }
+ }
}
}
\ No newline at end of file diff --git a/app/src/androidTest/java/org/rssin/neurons/NeuralNetworkTest.java b/app/src/androidTest/java/org/rssin/neurons/NeuralNetworkTest.java index b0f6eea..57776f3 100755 --- a/app/src/androidTest/java/org/rssin/neurons/NeuralNetworkTest.java +++ b/app/src/androidTest/java/org/rssin/neurons/NeuralNetworkTest.java @@ -7,12 +7,12 @@ public class NeuralNetworkTest extends TestCase { public void testAnd() throws Exception {
MultiNeuralNetwork nn = new MultiNeuralNetwork(10, 2);
- nn.addInput();
- nn.addInput();
- nn.addInput();
+ nn.addInput();//bias
+ nn.addInput();//inputA
+ nn.addInput();//inputB
//Simple AND
- for (int i = 0; i < 100; i++)
+ for (int i = 0; i < 300; i++)
{
PredictionInterface p1 = nn.computeOutput(new double[] {
1,
@@ -75,7 +75,7 @@ public class NeuralNetworkTest extends TestCase { nn.addInput();
//Simple AND
- for (int i = 0; i < 100; i++)
+ for (int i = 0; i < 300; i++)
{
PredictionInterface p1 = nn.computeOutput(new double[] {
1,
diff --git a/app/src/main/java/org/rssin/neurons/FeedSorter.java b/app/src/main/java/org/rssin/neurons/FeedSorter.java index 157a8f7..28d45c1 100755 --- a/app/src/main/java/org/rssin/neurons/FeedSorter.java +++ b/app/src/main/java/org/rssin/neurons/FeedSorter.java @@ -25,6 +25,7 @@ public class FeedSorter implements Serializable{ private final int MAX_TRAINING_HISTORY = 250;
private final int SECONDS_IN_DAY = 24 * 60 * 60;
+ private final SentenceSplitter splitter = new SentenceSplitter();
private MultiNeuralNetwork nn = new MultiNeuralNetwork(25, 50);
@@ -34,8 +35,8 @@ public class FeedSorter implements Serializable{ private int[] isNthWeekDayInput = new int[7];
private int isMorning, isAfternoon, isEvening, isNight, biasInput;
private Hashtable<String, Integer> categoryInputs = new Hashtable<String, Integer>();
- //private Hashtable<String, Integer> wordInputs = new Hashtable<String, Integer>();
- //private Hashtable<String, Integer> feedSourceInputs = new Hashtable<String, Integer>();
+ private Hashtable<String, Integer> wordInputs = new Hashtable<String, Integer>();
+ private Hashtable<String, Integer> authorInputs = new Hashtable<String, Integer>();
public FeedSorter() {
createNewNetwork();
@@ -60,46 +61,24 @@ public class FeedSorter implements Serializable{ }
private PredictionInterface getPrediction(FeedItem item) {
- //Add new inputs for categories.
- for(String category : item.getCategory())
- {
- category = category.toLowerCase();
- if(!categoryInputs.containsKey(category))
- {
- categoryInputs.put(category, nn.addInput());
- }
- }
+ List<String> words = splitter.splitSentence(item.getTitle());
- double[] inputs = new double[nn.getInputCount()];
+ addNewCategoryInputs(item);
+ addNewTitleWordInputs(words);
+ addNewAuthorInputs(item);
+ double[] inputs = newArrayInitializedToNegativeOne();
inputs[biasInput] = 1;
- //Initialize all inputs to -1 / false
- for(int i = 0; i < inputs.length; i++)
- {
- inputs[i] = -1;
- }
+ Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
//Set month
- Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
- for(int i = 0; i < isNthMonthInput.length; i++)
- {
- if(cal.get(Calendar.MONTH) - cal.getMinimum(Calendar.MONTH) == i)
- {
- inputs[isNthMonthInput[i]] = 1;
- }
- }
+ inputs[isNthMonthInput[cal.get(Calendar.MONTH) - cal.getMinimum(Calendar.MONTH)]] = 1;
//Set weekday
- for(int i = 0; i < isNthWeekDayInput.length; i++)
- {
- if(cal.get(Calendar.DAY_OF_WEEK) - cal.getMinimum(Calendar.DAY_OF_WEEK) == i)
- {
- inputs[isNthMonthInput[i]] = 1;
- }
- }
+ inputs[isNthWeekDayInput[cal.get(Calendar.DAY_OF_WEEK) - cal.getMinimum(Calendar.DAY_OF_WEEK)]] = 1;
- //Set day
+ //Set time
int hourOfDay = cal.get(Calendar.HOUR_OF_DAY);
if(hourOfDay > 6 && hourOfDay < 12)
{
@@ -120,9 +99,56 @@ public class FeedSorter implements Serializable{ inputs[categoryInputs.get(category.toLowerCase())] = 1;
}
+ for(String word : words)
+ {
+ inputs[wordInputs.get(word)] = 1;
+ }
+
+ if(item.getAuthor() != null) {
+ inputs[authorInputs.get(item.getAuthor().toLowerCase())] = 1;
+ }
+
return nn.computeOutput(inputs);
}
+ private double[] newArrayInitializedToNegativeOne() {
+ double[] inputs = new double[nn.getInputCount()];
+ Arrays.fill(inputs, 0, inputs.length, -1);
+ return inputs;
+ }
+
+ private void addNewCategoryInputs(FeedItem item) {
+ for(String category : item.getCategory())
+ {
+ category = category.toLowerCase();
+ if(!categoryInputs.containsKey(category))
+ {
+ categoryInputs.put(category, nn.addInput());
+ }
+ }
+ }
+
+ private void addNewAuthorInputs(FeedItem item)
+ {
+ if(item.getAuthor() != null) {
+ String author = item.getAuthor().toLowerCase();
+ if (!authorInputs.containsKey(author)) {
+ authorInputs.put(author, nn.addInput());
+ }
+ }
+ }
+
+ private void addNewTitleWordInputs(List<String> words) {
+ for(String word : words)
+ {
+ word = word.toLowerCase();
+ if(!wordInputs.containsKey(word))
+ {
+ wordInputs.put(word, nn.addInput());
+ }
+ }
+ }
+
/**
* Provides feedback to the neural network.
* @param item The feeditem.
@@ -167,12 +193,13 @@ public class FeedSorter implements Serializable{ */
public List<FeedItem> sortItems(List<FeedItem> items) {
// Sort list based on something like date + nn.computeOutput() * DAY.
- List<FeedItem> newItems = new ArrayList<FeedItem>(items);
+ final List<FeedItem> newItems = new ArrayList<FeedItem>(items);
final Hashtable<FeedItem, PredictionInterface> predictions = new Hashtable<>();
for(FeedItem feed : newItems)
{
- predictions.put(feed, getPrediction(feed));
+ PredictionInterface prediction = getPrediction(feed);
+ predictions.put(feed, prediction);
}
Collections.sort(newItems, new Comparator<FeedItem>() {
diff --git a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java index a2f06eb..68ff390 100755 --- a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java +++ b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java @@ -4,6 +4,7 @@ import java.io.Serializable; /**
* Created by Jos on 14-5-2015.
+ * Is used to migitate the problem of neural networks ending up in the wrong local minimum.
*/
class MultiNeuralNetwork implements Serializable{
private static final long serialVersionUID = 0;
diff --git a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java index 620d5bd..7fe003f 100755 --- a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java +++ b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java @@ -26,9 +26,11 @@ class NeuralNetwork implements Serializable{ }
public int addInput() {
+ assert hiddenNodes.length > 0;
+
int result = 0;
- for (int i = 0; i < hiddenNodes.length; i++) {
- result = hiddenNodes[i].addWeight();
+ for (Neuron hiddenNode : hiddenNodes) {
+ result = hiddenNode.addWeight();
}
return result;
@@ -37,6 +39,7 @@ class NeuralNetwork implements Serializable{ public PredictionInterface computeOutput(double[] inputs) {
double[] intermediateValues = new double[outputNode.getWeightCount()];
+ //Output of hidden neurons
for (int neuronNum = 0; neuronNum < hiddenNodes.length; neuronNum++) {
Neuron n = hiddenNodes[neuronNum];
@@ -57,6 +60,7 @@ class NeuralNetwork implements Serializable{ }
void learn(NeuralNetworkPrediction p, double expectedOutput) {
+ //TODO: See if adding momentum helps avoid local minima
double actualOutput = p.getOutput();
double[] intermediateValues = p.getIntermediateValues();
double[] inputs = p.getInputs();
@@ -76,8 +80,13 @@ class NeuralNetwork implements Serializable{ hiddenGradients[i] = hiddenDerivative * outputGradient * outputNode.getWeight(i);
}
+ updateWeights(intermediateValues, inputs, hiddenGradients, outputGradient);
+ }
+
+ private void updateWeights(double[] intermediateValues, double[] inputs, double[] hiddenGradients, double outputGradient) {
+ final double learningRate = 0.2;
+
//Update input => hidden weights.
- final double learningRate = 0.3;
for (int neuronNum = 0; neuronNum < hiddenNodes.length; neuronNum++) {
Neuron n = hiddenNodes[neuronNum];
diff --git a/app/src/main/java/org/rssin/neurons/Neuron.java b/app/src/main/java/org/rssin/neurons/Neuron.java index 724732d..d668ebc 100755 --- a/app/src/main/java/org/rssin/neurons/Neuron.java +++ b/app/src/main/java/org/rssin/neurons/Neuron.java @@ -23,7 +23,9 @@ class Neuron { }
public int addWeight() {
- weights.add(r.nextDouble() * 2 - 1);
+ // Initial values range from -.5 to .5. The exact value does not matter,
+ // as long as they aren't all 0.
+ weights.add(r.nextDouble() - .5);
return weights.size() - 1;
}
diff --git a/app/src/main/java/org/rssin/neurons/SentenceSplitter.java b/app/src/main/java/org/rssin/neurons/SentenceSplitter.java new file mode 100755 index 0000000..fc5f46f --- /dev/null +++ b/app/src/main/java/org/rssin/neurons/SentenceSplitter.java @@ -0,0 +1,27 @@ +package org.rssin.neurons;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Created by Jos on 21-5-2015.
+ */
+public class SentenceSplitter {
+ public SentenceSplitter()
+ { }
+
+ public List<String> splitSentence(String sentence)
+ {
+ List<String> allMatches = new ArrayList<>();
+ Matcher m = Pattern.compile("[\\w-]+").matcher(sentence);
+
+ while (m.find())
+ {
+ allMatches.add(m.group().toLowerCase());
+ }
+
+ return allMatches;
+ }
+}
|