aboutsummaryrefslogtreecommitdiff
path: root/Assignment 3/ex32.py
diff options
context:
space:
mode:
Diffstat (limited to 'Assignment 3/ex32.py')
-rw-r--r--Assignment 3/ex32.py65
1 files changed, 65 insertions, 0 deletions
diff --git a/Assignment 3/ex32.py b/Assignment 3/ex32.py
new file mode 100644
index 0000000..40e4c6a
--- /dev/null
+++ b/Assignment 3/ex32.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Oct 23 15:11:39 2015
+
+@author: Camil Staps, s4498062
+
+This is Python 2 code.
+"""
+
+import matplotlib.pyplot as plt
+from scipy import io as sciio
+from sklearn import tree
+from sklearn import cross_validation
+
+# 3.2.1
+wine = sciio.loadmat('./Data/wine.mat')
+data = wine['X']
+clss = wine['y']
+classNames = [str(n[0][0]) for n in wine['classNames']]
+
+X_train, X_test, y_train, y_test = cross_validation.train_test_split(data, clss)
+
+depths = range(2,21)
+optimal_depth, max_score, scores = 0, 0, []
+for depth in depths:
+ clf = tree.DecisionTreeClassifier(max_depth=depth, criterion='gini')
+ clf = clf.fit(X_train, y_train)
+ score = clf.score(X_test, y_test)
+ scores.append(score)
+ if score > max_score:
+ max_score, optimal_depth = score, depth
+
+print(optimal_depth, max_score)
+plt.plot(depths, scores, label='Holdout CV')
+
+# 3.2.2
+k = 10
+depths = range(2,21)
+
+optimal_depth, max_score, scores = 0, 0, []
+kf = cross_validation.KFold(len(data), k)
+for depth in depths:
+ temp_scores = []
+ for train, test in kf:
+ X_train, X_test = [data[i] for i in train], [data[i] for i in test]
+ y_train, y_test = [clss[i] for i in train], [clss[i] for i in test]
+
+ clf = tree.DecisionTreeClassifier(max_depth=depth, criterion='gini')
+ clf = clf.fit(X_train, y_train)
+ score = clf.score(X_test, y_test)
+ temp_scores.append(score)
+
+ score = np.mean(temp_scores)
+ scores.append(score)
+ if score > max_score:
+ max_score, optimal_depth = score, depth
+
+print(optimal_depth, max_score)
+plt.plot(depths, scores, label=str(k) + '-fold CV')
+
+plt.ylabel('Classification error')
+plt.xlabel('Tree depth')
+plt.legend(loc=4)
+plt.grid()
+plt.show()