From 3fd81c73cfd8bad36b2a1cf7955006e35c1d9db5 Mon Sep 17 00:00:00 2001 From: Camil Staps Date: Fri, 23 Oct 2015 16:44:24 +0200 Subject: Assignment 3: code, plots --- Assignment 3/ex32.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 Assignment 3/ex32.py (limited to 'Assignment 3/ex32.py') diff --git a/Assignment 3/ex32.py b/Assignment 3/ex32.py new file mode 100644 index 0000000..40e4c6a --- /dev/null +++ b/Assignment 3/ex32.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Oct 23 15:11:39 2015 + +@author: Camil Staps, s4498062 + +This is Python 2 code. +""" + +import matplotlib.pyplot as plt +from scipy import io as sciio +from sklearn import tree +from sklearn import cross_validation + +# 3.2.1 +wine = sciio.loadmat('./Data/wine.mat') +data = wine['X'] +clss = wine['y'] +classNames = [str(n[0][0]) for n in wine['classNames']] + +X_train, X_test, y_train, y_test = cross_validation.train_test_split(data, clss) + +depths = range(2,21) +optimal_depth, max_score, scores = 0, 0, [] +for depth in depths: + clf = tree.DecisionTreeClassifier(max_depth=depth, criterion='gini') + clf = clf.fit(X_train, y_train) + score = clf.score(X_test, y_test) + scores.append(score) + if score > max_score: + max_score, optimal_depth = score, depth + +print(optimal_depth, max_score) +plt.plot(depths, scores, label='Holdout CV') + +# 3.2.2 +k = 10 +depths = range(2,21) + +optimal_depth, max_score, scores = 0, 0, [] +kf = cross_validation.KFold(len(data), k) +for depth in depths: + temp_scores = [] + for train, test in kf: + X_train, X_test = [data[i] for i in train], [data[i] for i in test] + y_train, y_test = [clss[i] for i in train], [clss[i] for i in test] + + clf = tree.DecisionTreeClassifier(max_depth=depth, criterion='gini') + clf = clf.fit(X_train, y_train) + score = clf.score(X_test, y_test) + temp_scores.append(score) + + score = np.mean(temp_scores) + scores.append(score) + if score > max_score: + max_score, optimal_depth = score, depth + +print(optimal_depth, max_score) +plt.plot(depths, scores, label=str(k) + '-fold CV') + +plt.ylabel('Classification error') +plt.xlabel('Tree depth') +plt.legend(loc=4) +plt.grid() +plt.show() -- cgit v1.2.3