From a0c19286783e040d0ee87a70b8257f99474b8714 Mon Sep 17 00:00:00 2001 From: Size43 Date: Thu, 21 May 2015 15:28:58 +0200 Subject: Reduced LOC + FeedSorter sorts by category, author, title and date/time + reduced learning rate from 0.3 to 0.2 --- .../java/org/rssin/neurons/FeedSorterTest.java | 96 +++++++++++++++++---- .../java/org/rssin/neurons/NeuralNetworkTest.java | 10 +-- .../main/java/org/rssin/neurons/FeedSorter.java | 97 ++++++++++++++-------- .../java/org/rssin/neurons/MultiNeuralNetwork.java | 1 + .../main/java/org/rssin/neurons/NeuralNetwork.java | 15 +++- app/src/main/java/org/rssin/neurons/Neuron.java | 4 +- .../java/org/rssin/neurons/SentenceSplitter.java | 27 ++++++ 7 files changed, 189 insertions(+), 61 deletions(-) create mode 100755 app/src/main/java/org/rssin/neurons/SentenceSplitter.java diff --git a/app/src/androidTest/java/org/rssin/neurons/FeedSorterTest.java b/app/src/androidTest/java/org/rssin/neurons/FeedSorterTest.java index a2d59ee..a6e3581 100755 --- a/app/src/androidTest/java/org/rssin/neurons/FeedSorterTest.java +++ b/app/src/androidTest/java/org/rssin/neurons/FeedSorterTest.java @@ -20,7 +20,7 @@ public class FeedSorterTest extends TestCase { Assert.assertTrue(true); } - public void testSortItems() throws Exception { + public void testSortItemsByCategory() throws Exception { List gameList = new ArrayList<>(); gameList.add("Games"); List sportList = new ArrayList<>(); @@ -39,28 +39,90 @@ public class FeedSorterTest extends TestCase { FeedSorter s = new FeedSorter(); //I like games & I hate sports - for(int i = 0; i < 100; i++) - { - for(FeedItem item : likedItems) - { - s.feedback(item, Feedback.Like); - } - - for(FeedItem item : dislikedItems) - { - s.feedback(item, Feedback.Dislike); - } - } + trainNetwork(likedItems, dislikedItems, s); FeedItem sportsItem = new FeedItem("", new Date(2014, 1, 1, 1, 2, 1), "SPORT ARTICLE", "DESCRIPTION", "", "Randy", sportList, "", "", ""); FeedItem gamesItem = new FeedItem("", new Date(2014, 1, 1, 1, 1, 1), "GAME ARTICLE", "DESCRIPTION", "", "Randy", gameList, "", "", ""); + testSortingOrder(s, sportsItem, gamesItem); + } + + public void testSortItemsByTitle() throws Exception { + List emptyList = new ArrayList<>(); + + FeedItem[] likedItems = new FeedItem[] + { + new FeedItem("", new Date(), "Video games are cool", "DESCRIPTION", "", "Randy", emptyList, "", "", ""), + new FeedItem("", new Date(), "The new video game", "DESCRIPTION", "", "Camil", emptyList, "", "", ""), + new FeedItem("", new Date(), "Best games of 2015", "DESCRIPTION", "", "Jos", emptyList, "", "", ""), + }; + + FeedItem[] dislikedItems = new FeedItem[] + { + new FeedItem("", new Date(), "Video of a cat", "DESCRIPTION", "", "Randy", emptyList, "", "", ""), + new FeedItem("", new Date(), "It's raining", "DESCRIPTION", "", "Joep", emptyList, "", "", ""), + new FeedItem("", new Date(), "Shocking video of a cat in the rain.", "DESCRIPTION", "", "Joep", emptyList, "", "", ""), + }; + + FeedSorter s = new FeedSorter(); + + //I like games & I hate sports + trainNetwork(likedItems, dislikedItems, s); + + FeedItem dislikedItem = new FeedItem("", new Date(2014, 1, 1, 1, 2, 1), "Another cool video of a cat in the sun.", "DESCRIPTION", "", "Randy", emptyList, "", "", ""); + FeedItem likedItem = new FeedItem("", new Date(2014, 1, 1, 1, 1, 1), "Coolest retro games", "DESCRIPTION", "", "Jos", emptyList, "", "", ""); + + testSortingOrder(s, dislikedItem, likedItem); + } + + public void testSortItemsByAuthor() throws Exception { + List emptyList = new ArrayList<>(); + + FeedItem[] likedItems = new FeedItem[] + { + new FeedItem("", new Date(), "Best games of 2015", "DESCRIPTION", "", "Jos", emptyList, "", "", ""), + new FeedItem("", new Date(), "It's raining cats and dogs!", "DESCRIPTION", "", "Jos", emptyList, "", "", ""), + }; + + FeedItem[] dislikedItems = new FeedItem[] + { + new FeedItem("", new Date(), "Video of a cat", "DESCRIPTION", "", "Randy", emptyList, "", "", ""), + new FeedItem("", new Date(), "It's raining", "DESCRIPTION", "", "Joep", emptyList, "", "", ""), + new FeedItem("", new Date(), "Shocking video of a cat in the rain.", "DESCRIPTION", "", "Joep", emptyList, "", "", ""), + new FeedItem("", new Date(), "Video games are cool", "DESCRIPTION", "", "Randy", emptyList, "", "", ""), + new FeedItem("", new Date(), "The new video game", "DESCRIPTION", "", "Camil", emptyList, "", "", ""), + }; + + FeedSorter s = new FeedSorter(); + + //I like games & I hate sports + trainNetwork(likedItems, dislikedItems, s); + + FeedItem dislikedItem = new FeedItem("", new Date(2014, 1, 1, 1, 2, 1), "Another cool video of a cat in the sun.", "DESCRIPTION", "", "Randy", emptyList, "", "", ""); + FeedItem likedItem = new FeedItem("", new Date(2014, 1, 1, 1, 1, 1), "Coolest retro games", "DESCRIPTION", "", "Jos", emptyList, "", "", ""); + + testSortingOrder(s, dislikedItem, likedItem); + } + + private void testSortingOrder(FeedSorter s, FeedItem dislikedItem, FeedItem likedItem) { List testItems = new LinkedList<>(); - testItems.add(sportsItem); - testItems.add(gamesItem); + testItems.add(dislikedItem); + testItems.add(likedItem); List sortedItems = s.sortItems(testItems); - Assert.assertEquals(sortedItems.get(0), gamesItem); - Assert.assertEquals(sortedItems.get(1), sportsItem); + Assert.assertEquals(sortedItems.get(0), likedItem); + Assert.assertEquals(sortedItems.get(1), dislikedItem); + } + + private void trainNetwork(FeedItem[] likedItems, FeedItem[] dislikedItems, FeedSorter s) { + for(int i = 0; i < 200; i++) { + for (FeedItem item : likedItems) { + s.feedback(item, Feedback.Like); + } + + for (FeedItem item : dislikedItems) { + s.feedback(item, Feedback.Dislike); + } + } } } \ No newline at end of file diff --git a/app/src/androidTest/java/org/rssin/neurons/NeuralNetworkTest.java b/app/src/androidTest/java/org/rssin/neurons/NeuralNetworkTest.java index b0f6eea..57776f3 100755 --- a/app/src/androidTest/java/org/rssin/neurons/NeuralNetworkTest.java +++ b/app/src/androidTest/java/org/rssin/neurons/NeuralNetworkTest.java @@ -7,12 +7,12 @@ public class NeuralNetworkTest extends TestCase { public void testAnd() throws Exception { MultiNeuralNetwork nn = new MultiNeuralNetwork(10, 2); - nn.addInput(); - nn.addInput(); - nn.addInput(); + nn.addInput();//bias + nn.addInput();//inputA + nn.addInput();//inputB //Simple AND - for (int i = 0; i < 100; i++) + for (int i = 0; i < 300; i++) { PredictionInterface p1 = nn.computeOutput(new double[] { 1, @@ -75,7 +75,7 @@ public class NeuralNetworkTest extends TestCase { nn.addInput(); //Simple AND - for (int i = 0; i < 100; i++) + for (int i = 0; i < 300; i++) { PredictionInterface p1 = nn.computeOutput(new double[] { 1, diff --git a/app/src/main/java/org/rssin/neurons/FeedSorter.java b/app/src/main/java/org/rssin/neurons/FeedSorter.java index 157a8f7..28d45c1 100755 --- a/app/src/main/java/org/rssin/neurons/FeedSorter.java +++ b/app/src/main/java/org/rssin/neurons/FeedSorter.java @@ -25,6 +25,7 @@ public class FeedSorter implements Serializable{ private final int MAX_TRAINING_HISTORY = 250; private final int SECONDS_IN_DAY = 24 * 60 * 60; + private final SentenceSplitter splitter = new SentenceSplitter(); private MultiNeuralNetwork nn = new MultiNeuralNetwork(25, 50); @@ -34,8 +35,8 @@ public class FeedSorter implements Serializable{ private int[] isNthWeekDayInput = new int[7]; private int isMorning, isAfternoon, isEvening, isNight, biasInput; private Hashtable categoryInputs = new Hashtable(); - //private Hashtable wordInputs = new Hashtable(); - //private Hashtable feedSourceInputs = new Hashtable(); + private Hashtable wordInputs = new Hashtable(); + private Hashtable authorInputs = new Hashtable(); public FeedSorter() { createNewNetwork(); @@ -60,46 +61,24 @@ public class FeedSorter implements Serializable{ } private PredictionInterface getPrediction(FeedItem item) { - //Add new inputs for categories. - for(String category : item.getCategory()) - { - category = category.toLowerCase(); - if(!categoryInputs.containsKey(category)) - { - categoryInputs.put(category, nn.addInput()); - } - } + List words = splitter.splitSentence(item.getTitle()); - double[] inputs = new double[nn.getInputCount()]; + addNewCategoryInputs(item); + addNewTitleWordInputs(words); + addNewAuthorInputs(item); + double[] inputs = newArrayInitializedToNegativeOne(); inputs[biasInput] = 1; - //Initialize all inputs to -1 / false - for(int i = 0; i < inputs.length; i++) - { - inputs[i] = -1; - } + Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); //Set month - Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); - for(int i = 0; i < isNthMonthInput.length; i++) - { - if(cal.get(Calendar.MONTH) - cal.getMinimum(Calendar.MONTH) == i) - { - inputs[isNthMonthInput[i]] = 1; - } - } + inputs[isNthMonthInput[cal.get(Calendar.MONTH) - cal.getMinimum(Calendar.MONTH)]] = 1; //Set weekday - for(int i = 0; i < isNthWeekDayInput.length; i++) - { - if(cal.get(Calendar.DAY_OF_WEEK) - cal.getMinimum(Calendar.DAY_OF_WEEK) == i) - { - inputs[isNthMonthInput[i]] = 1; - } - } + inputs[isNthWeekDayInput[cal.get(Calendar.DAY_OF_WEEK) - cal.getMinimum(Calendar.DAY_OF_WEEK)]] = 1; - //Set day + //Set time int hourOfDay = cal.get(Calendar.HOUR_OF_DAY); if(hourOfDay > 6 && hourOfDay < 12) { @@ -120,9 +99,56 @@ public class FeedSorter implements Serializable{ inputs[categoryInputs.get(category.toLowerCase())] = 1; } + for(String word : words) + { + inputs[wordInputs.get(word)] = 1; + } + + if(item.getAuthor() != null) { + inputs[authorInputs.get(item.getAuthor().toLowerCase())] = 1; + } + return nn.computeOutput(inputs); } + private double[] newArrayInitializedToNegativeOne() { + double[] inputs = new double[nn.getInputCount()]; + Arrays.fill(inputs, 0, inputs.length, -1); + return inputs; + } + + private void addNewCategoryInputs(FeedItem item) { + for(String category : item.getCategory()) + { + category = category.toLowerCase(); + if(!categoryInputs.containsKey(category)) + { + categoryInputs.put(category, nn.addInput()); + } + } + } + + private void addNewAuthorInputs(FeedItem item) + { + if(item.getAuthor() != null) { + String author = item.getAuthor().toLowerCase(); + if (!authorInputs.containsKey(author)) { + authorInputs.put(author, nn.addInput()); + } + } + } + + private void addNewTitleWordInputs(List words) { + for(String word : words) + { + word = word.toLowerCase(); + if(!wordInputs.containsKey(word)) + { + wordInputs.put(word, nn.addInput()); + } + } + } + /** * Provides feedback to the neural network. * @param item The feeditem. @@ -167,12 +193,13 @@ public class FeedSorter implements Serializable{ */ public List sortItems(List items) { // Sort list based on something like date + nn.computeOutput() * DAY. - List newItems = new ArrayList(items); + final List newItems = new ArrayList(items); final Hashtable predictions = new Hashtable<>(); for(FeedItem feed : newItems) { - predictions.put(feed, getPrediction(feed)); + PredictionInterface prediction = getPrediction(feed); + predictions.put(feed, prediction); } Collections.sort(newItems, new Comparator() { diff --git a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java index a2f06eb..68ff390 100755 --- a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java +++ b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java @@ -4,6 +4,7 @@ import java.io.Serializable; /** * Created by Jos on 14-5-2015. + * Is used to migitate the problem of neural networks ending up in the wrong local minimum. */ class MultiNeuralNetwork implements Serializable{ private static final long serialVersionUID = 0; diff --git a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java index 620d5bd..7fe003f 100755 --- a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java +++ b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java @@ -26,9 +26,11 @@ class NeuralNetwork implements Serializable{ } public int addInput() { + assert hiddenNodes.length > 0; + int result = 0; - for (int i = 0; i < hiddenNodes.length; i++) { - result = hiddenNodes[i].addWeight(); + for (Neuron hiddenNode : hiddenNodes) { + result = hiddenNode.addWeight(); } return result; @@ -37,6 +39,7 @@ class NeuralNetwork implements Serializable{ public PredictionInterface computeOutput(double[] inputs) { double[] intermediateValues = new double[outputNode.getWeightCount()]; + //Output of hidden neurons for (int neuronNum = 0; neuronNum < hiddenNodes.length; neuronNum++) { Neuron n = hiddenNodes[neuronNum]; @@ -57,6 +60,7 @@ class NeuralNetwork implements Serializable{ } void learn(NeuralNetworkPrediction p, double expectedOutput) { + //TODO: See if adding momentum helps avoid local minima double actualOutput = p.getOutput(); double[] intermediateValues = p.getIntermediateValues(); double[] inputs = p.getInputs(); @@ -76,8 +80,13 @@ class NeuralNetwork implements Serializable{ hiddenGradients[i] = hiddenDerivative * outputGradient * outputNode.getWeight(i); } + updateWeights(intermediateValues, inputs, hiddenGradients, outputGradient); + } + + private void updateWeights(double[] intermediateValues, double[] inputs, double[] hiddenGradients, double outputGradient) { + final double learningRate = 0.2; + //Update input => hidden weights. - final double learningRate = 0.3; for (int neuronNum = 0; neuronNum < hiddenNodes.length; neuronNum++) { Neuron n = hiddenNodes[neuronNum]; diff --git a/app/src/main/java/org/rssin/neurons/Neuron.java b/app/src/main/java/org/rssin/neurons/Neuron.java index 724732d..d668ebc 100755 --- a/app/src/main/java/org/rssin/neurons/Neuron.java +++ b/app/src/main/java/org/rssin/neurons/Neuron.java @@ -23,7 +23,9 @@ class Neuron { } public int addWeight() { - weights.add(r.nextDouble() * 2 - 1); + // Initial values range from -.5 to .5. The exact value does not matter, + // as long as they aren't all 0. + weights.add(r.nextDouble() - .5); return weights.size() - 1; } diff --git a/app/src/main/java/org/rssin/neurons/SentenceSplitter.java b/app/src/main/java/org/rssin/neurons/SentenceSplitter.java new file mode 100755 index 0000000..fc5f46f --- /dev/null +++ b/app/src/main/java/org/rssin/neurons/SentenceSplitter.java @@ -0,0 +1,27 @@ +package org.rssin.neurons; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Created by Jos on 21-5-2015. + */ +public class SentenceSplitter { + public SentenceSplitter() + { } + + public List splitSentence(String sentence) + { + List allMatches = new ArrayList<>(); + Matcher m = Pattern.compile("[\\w-]+").matcher(sentence); + + while (m.find()) + { + allMatches.add(m.group().toLowerCase()); + } + + return allMatches; + } +} -- cgit v1.2.3