diff options
author | Size43 | 2015-06-10 11:39:50 +0200 |
---|---|---|
committer | Size43 | 2015-06-10 11:39:50 +0200 |
commit | b8b188a42d39cc95013686a71459a792bf5bfa86 (patch) | |
tree | 789e5f116039d180200c34acd78290b3579629d9 /app/src | |
parent | Removed hardcoded strings & only one TrainingCase is stored per Article. (diff) |
Added comments to rssin.neurons.*
Diffstat (limited to 'app/src')
10 files changed, 149 insertions, 10 deletions
diff --git a/app/src/main/java/org/rssin/neurons/FeedSorter.java b/app/src/main/java/org/rssin/neurons/FeedSorter.java index 8887b99..660c546 100755 --- a/app/src/main/java/org/rssin/neurons/FeedSorter.java +++ b/app/src/main/java/org/rssin/neurons/FeedSorter.java @@ -19,6 +19,7 @@ import java.util.TimeZone; /**
* @author Jos.
+ * Sorts lists of feeds based on the output of a neural network, using date/time + title + author + category as inputs.
*/
public class FeedSorter implements Storable {
private static final long serialVersionUID = 0;
@@ -64,6 +65,9 @@ public class FeedSorter implements Storable { authorInputs = (Hashtable<String, Integer>) stream.readObject();
}
+ /**
+ * Initializes a new FeedSorter, and creates the basic inputs for the Neural network.4
+ */
public FeedSorter() {
biasInput = nn.addInput();
for (int i = 0; i < 12; i++) {
@@ -80,6 +84,11 @@ public class FeedSorter implements Storable { isNight = nn.addInput();
}
+ /**
+ * Returns a prediction for the provided FeedItem.
+ * @param item The item to predict the interest of the user of.
+ * @return The prediction.
+ */
private PredictionInterface getPrediction(FeedItem item) {
List<String> words = splitter.splitSentence(item.getTitle());
@@ -123,18 +132,31 @@ public class FeedSorter implements Storable { return nn.computeOutput(inputs);
}
+ /**
+ * @return a new array, the same size as the number of inputs, initialized to -1.
+ */
private double[] newArrayInitializedToNegativeOne() {
double[] inputs = new double[nn.getInputCount()];
Arrays.fill(inputs, 0, inputs.length, -1);
return inputs;
}
+ /**
+ * Adds new inputs to the Neural Network
+ * @param words The string identifiers of the inputs.
+ * @param map The map of string identifiers to input IDs. This map is modified.
+ */
private void addNewInputs(Iterable<String> words, Hashtable<String, Integer> map) {
for (String word : words) {
addNewInput(word, map);
}
}
+ /**
+ * Adds a single new input to the neural network.
+ * @param word The string identifier.
+ * @param map The map op string identifiers to input IDs. This map is modified.
+ */
private void addNewInput(String word, Hashtable<String, Integer> map) {
if (word != null) {
word = word.toLowerCase();
diff --git a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java index ece53a8..b53588b 100755 --- a/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java +++ b/app/src/main/java/org/rssin/neurons/MultiNeuralNetwork.java @@ -21,6 +21,11 @@ class MultiNeuralNetwork implements Serializable { networks = SerializationTools.readArray(stream, NeuralNetwork.class);
}
+ /**
+ * Creates a new instance of the MultiNeuralNetwork
+ * @param numNetworks The number of underlying networks.
+ * @param numHiddenNodes The number of hidden nodes each network has.
+ */
public MultiNeuralNetwork(int numNetworks, int numHiddenNodes) {
networks = new NeuralNetwork[numNetworks];
for (int i = 0; i < networks.length; i++) {
@@ -28,15 +33,26 @@ class MultiNeuralNetwork implements Serializable { }
}
+ /**
+ * Calls addInput() on all NeuralNetworks.
+ * @return the new input ID.
+ */
public int addInput() {
int id = 0;
for (NeuralNetwork network : networks) {
id = network.addInput();
}
+ // Unless a neural network's addInput() is called individually, which should be impossible
+ // because the list of networks is private, all IDs returned by network.addInput() are the same.
return id;
}
+ /**
+ * Returns a prediction from the Neural Networks based on the inputs.
+ * @param inputs The list of double inputs.
+ * @return The prediction.
+ */
public PredictionInterface computeOutput(double[] inputs) {
PredictionInterface[] predictions = new PredictionInterface[networks.length];
for (int i = 0; i < predictions.length; i++) {
@@ -46,7 +62,12 @@ class MultiNeuralNetwork implements Serializable { return new MultiNeuralNetworkPrediction(predictions);
}
+ /**
+ * Returns the number of inputs in the neural network.
+ * @return
+ */
public int getInputCount() {
+ // All neural networks will have the same number of inputs. See addInput().
return networks[0].getInputCount();
}
}
diff --git a/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java b/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java index f618749..c03cab3 100755 --- a/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java +++ b/app/src/main/java/org/rssin/neurons/MultiNeuralNetworkPrediction.java @@ -2,6 +2,7 @@ package org.rssin.neurons; /**
* @author Jos.
+ * A prediction that bases its result on the average of multiple other predictions.
*/
class MultiNeuralNetworkPrediction implements PredictionInterface {
private final PredictionInterface[] predictions;
@@ -14,6 +15,9 @@ class MultiNeuralNetworkPrediction implements PredictionInterface { this.predictions = predictions;
}
+ /**
+ * @return the prediction, as a value in [-1, 1].
+ */
public double getOutput() {
double average = 0;
for (PredictionInterface prediction : predictions) {
@@ -23,12 +27,19 @@ class MultiNeuralNetworkPrediction implements PredictionInterface { return average / (double) predictions.length;
}
+ /**
+ * Provides the neural networks with feedback.
+ * @param expectedOutput The expected output for the input values (getInputs()).
+ */
public void learn(double expectedOutput) {
for (PredictionInterface prediction : predictions) {
prediction.learn(expectedOutput);
}
}
+ /**
+ * @return The inputs provided to the neural network.
+ */
public double[] getInputs() {
return predictions[0].getInputs();
}
diff --git a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java index c990be9..446e13a 100755 --- a/app/src/main/java/org/rssin/neurons/NeuralNetwork.java +++ b/app/src/main/java/org/rssin/neurons/NeuralNetwork.java @@ -9,6 +9,7 @@ import java.io.Serializable; /**
* @author Jos.
+ * A 'simple' implementation of a neural network with one hidden layer.
*/
class NeuralNetwork implements Serializable {
private static final long serialVersionUID = 0;
@@ -25,6 +26,10 @@ class NeuralNetwork implements Serializable { outputNode = (Neuron) stream.readObject();
}
+ /**
+ * Creates a new instance of the neural network.
+ * @param numHiddenNodes The number of hidden nodes in the neural network.
+ */
NeuralNetwork(int numHiddenNodes) {
if (numHiddenNodes < 1) {
throw new IllegalArgumentException("numHiddenNodes must be > 0");
@@ -39,6 +44,9 @@ class NeuralNetwork implements Serializable { outputNode = new Neuron(numHiddenNodes + 1);
}
+ /**
+ * @return the new input ID.
+ */
@SuppressLint("Assert")
int addInput() {
assert hiddenNodes.length > 0;
@@ -51,6 +59,11 @@ class NeuralNetwork implements Serializable { return result;
}
+ /**
+ * Calculcates the output of the neural network for the given input values.
+ * @param inputs The inputs for the neural network.
+ * @return The output of the neural network.
+ */
PredictionInterface computeOutput(double[] inputs) {
double[] intermediateValues = new double[outputNode.getWeightCount()];
@@ -74,8 +87,12 @@ class NeuralNetwork implements Serializable { return new NeuralNetworkPrediction(this, inputs, intermediateValues, HyperTan(result));
}
+ /**
+ * Train the neural network
+ * @param p The original prediction
+ * @param expectedOutput The expected output of the prediction.
+ */
void learn(NeuralNetworkPrediction p, double expectedOutput) {
- //TODO: See if adding momentum helps avoid local minimum
double actualOutput = p.getOutput();
double[] intermediateValues = p.getIntermediateValues();
double[] inputs = p.getInputs();
@@ -97,6 +114,13 @@ class NeuralNetwork implements Serializable { updateWeights(intermediateValues, inputs, hiddenGradients, outputGradient);
}
+ /**
+ * Method to update the weights of the nodes.
+ * @param intermediateValues The intermediate values generated by the hidden nodes
+ * @param inputs The input values
+ * @param hiddenGradients The hidden gradients, calculated in learn().
+ * @param outputGradient The output gradients, calculated in learn().
+ */
private void updateWeights(double[] intermediateValues, double[] inputs, double[] hiddenGradients, double outputGradient) {
final double learningRate = 0.2;
@@ -119,12 +143,20 @@ class NeuralNetwork implements Serializable { }
}
+ /**
+ * HyperTan that returns -1 or 1 when the value is smaller than -10 or bigger than 10, respectively.
+ * @param x The input
+ * @return The result of HyperTan.
+ */
private static double HyperTan(double x) {
if (x < -10.0) return -1.0;
else if (x > 10.0) return 1.0;
else return Math.tanh(x);
}
+ /**
+ * @return The number of inputs for this neural network.
+ */
int getInputCount() {
return hiddenNodes[0].getWeightCount();
}
diff --git a/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java b/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java index 169caee..40dca02 100755 --- a/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java +++ b/app/src/main/java/org/rssin/neurons/NeuralNetworkPrediction.java @@ -2,6 +2,7 @@ package org.rssin.neurons; /**
* @author Jos.
+ * A prediction made by the neural network.
*/
class NeuralNetworkPrediction implements PredictionInterface {
private final double[] inputs;
@@ -16,18 +17,31 @@ class NeuralNetworkPrediction implements PredictionInterface { this.nn = nn;
}
+ /**
+ * @return The inputs that resulted in this prediction.
+ */
public double[] getInputs() {
return inputs;
}
+ /**
+ * @return The intermediate values that resulted in the output, as calculated by NeuralNetwork.computeOutput.
+ */
public double[] getIntermediateValues() {
return intermediateValues;
}
+ /**
+ * @return The output of the neural network.
+ */
public double getOutput() {
return output;
}
+ /**
+ * Provides feedback to the neural network.
+ * @param expectedOutput The expected output for the input values.
+ */
public void learn(double expectedOutput) {
nn.learn(this, expectedOutput);
}
diff --git a/app/src/main/java/org/rssin/neurons/Neuron.java b/app/src/main/java/org/rssin/neurons/Neuron.java index 203a450..760ea77 100755 --- a/app/src/main/java/org/rssin/neurons/Neuron.java +++ b/app/src/main/java/org/rssin/neurons/Neuron.java @@ -10,6 +10,7 @@ import java.util.Random; /**
* @author Jos.
+ * A class that contains the weights for a single neuron in the neural network.
*/
class Neuron implements Serializable {
private static final long serialVersionUID = 0;
@@ -35,6 +36,10 @@ class Neuron implements Serializable { }
}
+ /**
+ * Adds a weight to the neuron.
+ * @return The index of the weight.
+ */
public int addWeight() {
// Initial values range from -.5 to .5. The exact value does not matter,
// as long as they aren't all 0.
@@ -42,14 +47,27 @@ class Neuron implements Serializable { return weights.size() - 1;
}
+ /**
+ * Returns the value of a weight
+ * @param i The weight index
+ * @return The value of the weight
+ */
public double getWeight(int i) {
return weights.get(i);
}
+ /**
+ * Modifies a weight, by adding delta to its value
+ * @param i The weight index
+ * @param delta The amount the value will change.
+ */
public void adjustWeight(int i, double delta) {
weights.set(i, weights.get(i) + delta);
}
+ /**
+ * @return the number of weights in this neuron.
+ */
public int getWeightCount() {
return weights.size();
}
diff --git a/app/src/main/java/org/rssin/neurons/PredictionInterface.java b/app/src/main/java/org/rssin/neurons/PredictionInterface.java index ff46992..891707a 100755 --- a/app/src/main/java/org/rssin/neurons/PredictionInterface.java +++ b/app/src/main/java/org/rssin/neurons/PredictionInterface.java @@ -2,11 +2,23 @@ package org.rssin.neurons; /**
* @author Jos.
+ * Interface for predictions generated by either NeuralNetwork or MultiNeuralNetwork.
*/
-interface PredictionInterface {
+interface PredictionInterface
+{
+ /**
+ * @return The output of the neural network.
+ */
double getOutput();
+ /**
+ * Provides feedback to the neural network.
+ * @param expectedOutput The expected output for this prediction.
+ */
void learn(double expectedOutput);
+ /**
+ * @return The inputs that were given to the neural network.
+ */
double[] getInputs();
}
diff --git a/app/src/main/java/org/rssin/neurons/SentenceSplitter.java b/app/src/main/java/org/rssin/neurons/SentenceSplitter.java index 29e34bc..002071e 100755 --- a/app/src/main/java/org/rssin/neurons/SentenceSplitter.java +++ b/app/src/main/java/org/rssin/neurons/SentenceSplitter.java @@ -12,7 +12,8 @@ import java.util.regex.Pattern; public class SentenceSplitter implements Serializable {
private static final long serialVersionUID = 0;
- private final Pattern wordMatch = Pattern.compile("[\\w-]+");//For unicode support, add the Pattern.UNICODE_CHARACTER_CLASS flag. Works only in Java 7+.
+ //For unicode support, add the Pattern.UNICODE_CHARACTER_CLASS flag. Works only in Java 7+, currently not supported on Android.
+ private final Pattern wordMatch = Pattern.compile("[\\w-]+");
public SentenceSplitter() {
}
diff --git a/app/src/main/java/org/rssin/neurons/TrainingCase.java b/app/src/main/java/org/rssin/neurons/TrainingCase.java index 5c6172c..d1fee3a 100755 --- a/app/src/main/java/org/rssin/neurons/TrainingCase.java +++ b/app/src/main/java/org/rssin/neurons/TrainingCase.java @@ -4,6 +4,7 @@ import java.io.Serializable; /**
* @author Jos.
+ * A training case for the FeedSorter.
*/
public class TrainingCase implements Serializable {
private static final long serialVersionUID = 0;
@@ -15,14 +16,23 @@ public class TrainingCase implements Serializable { this.feedback = feedback;
}
+ /**
+ * @return The inputs that were given to the neural network.
+ */
double[] getInputs() {
return inputs;
}
+ /**
+ * @return The expected prediction
+ */
public Feedback getFeedback() {
return feedback;
}
+ /**
+ * @param feedback The new expected prediction.
+ */
public void setFeedback(Feedback feedback)
{
this.feedback = feedback;
diff --git a/app/src/main/java/org/rssin/rssin/FeedLoaderAndSorter.java b/app/src/main/java/org/rssin/rssin/FeedLoaderAndSorter.java index a317a27..524ac46 100755 --- a/app/src/main/java/org/rssin/rssin/FeedLoaderAndSorter.java +++ b/app/src/main/java/org/rssin/rssin/FeedLoaderAndSorter.java @@ -16,11 +16,15 @@ import java.util.List; /**
* Created by Jos on 20-5-2015.
- * @todo javadoc
+ * A class that loads & sorts FeedItems from a list of Feeds.
*/
public class FeedLoaderAndSorter {
private final List<Feed> feeds;
+ /**
+ * Creates a new FeedLoaderAndSorter that can load FeedItems the list of feeds.
+ * @param feeds The list of Feeds to load FeedItems from.
+ */
public FeedLoaderAndSorter(List<Feed> feeds) {
this.feeds = feeds;
}
@@ -81,10 +85,4 @@ public class FeedLoaderAndSorter { return count == 0;
}
}
-
- private static boolean contains(String haystack, String needle) {
- return haystack != null
- && needle != null
- && haystack.toLowerCase().contains(needle.toLowerCase());
- }
}
|