aboutsummaryrefslogtreecommitdiff
path: root/Assignment 4/packages/clusterPlot.py
diff options
context:
space:
mode:
authorCamil Staps2015-11-27 00:18:32 +0100
committerCamil Staps2015-11-27 00:18:32 +0100
commitd88d00232cfdbfd508834911af6ad89a217b84e1 (patch)
tree20308e5e89f76ce8f987598e26f75db6ad4cbd4e /Assignment 4/packages/clusterPlot.py
parentAssignment 3 report (diff)
Start assignment 4
Diffstat (limited to 'Assignment 4/packages/clusterPlot.py')
-rw-r--r--Assignment 4/packages/clusterPlot.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/Assignment 4/packages/clusterPlot.py b/Assignment 4/packages/clusterPlot.py
new file mode 100644
index 0000000..2f37a3d
--- /dev/null
+++ b/Assignment 4/packages/clusterPlot.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Apr 14 09:01:18 2014
+
+"""
+
def _is_unset(value):
    '''Return True if an optional argument still holds its "not given" sentinel.'''
    # The public default is the string 'None' (kept for backward
    # compatibility); a real None is accepted as well. The type is tested
    # first so that a numpy array is never compared against a string, which
    # is an elementwise operation and cannot be used as a plain boolean.
    return value is None or (isinstance(value, str) and value == 'None')


def _gauss_2d(centroid, ccov, std=2, points=100):
    '''
    Return the x and y coordinates of a closed contour of a 2-D Gaussian.

    The contour is the ellipse lying `std` standard deviations away from
    `centroid` under covariance `ccov`, sampled at `points` angles. Only the
    first two dimensions of `centroid` and `ccov` are used, matching the two
    columns of the data that clusterPlot draws.
    '''
    import numpy as np

    mean = np.asarray(centroid, dtype=float).ravel()[:2]
    cov = np.asarray(ccov, dtype=float)[:2, :2]
    angles = np.linspace(0, 2 * np.pi, points)
    unit_circle = np.vstack((np.cos(angles), np.sin(angles)))
    # eigh: a covariance matrix is symmetric, so this yields real eigenvalues
    # and orthonormal principal axes.
    eigvals, eigvecs = np.linalg.eigh(cov)
    # Guard against tiny negative eigenvalues caused by round-off.
    radii = std * np.sqrt(np.maximum(eigvals, 0))
    contour = eigvecs.dot(radii[:, np.newaxis] * unit_circle) + mean[:, np.newaxis]
    return contour[0], contour[1]


def clusterPlot(X, clusterid, centroids='None', y='None', covars='None', figsize=(16,10)):
    '''
    CLUSTERPLOT Plots a clustering of a data set as well as the true class
    labels. Only the first two columns of X are drawn, so data with more than
    two attributes should first be projected (e.g. onto the first two
    principal components). Data objects are plotted as a dot with a circle
    around. The color of the dot indicates the true class, and the circle
    indicates the cluster index. Optionally, the centroids are plotted as
    filled-star markers, and ellipses corresponding to covariance matrices
    (e.g. for gaussian mixture models) are drawn around them.

    Usage:
    clusterplot(X, clusterid)
    clusterplot(X, clusterid, centroids=c_matrix, y=y_matrix)
    clusterplot(X, clusterid, centroids=c_matrix, y=y_matrix, covars=c_tensor)

    Input:
    X          N-by-M data matrix (N data objects with M attributes)
    clusterid  N-by-1 vector of cluster indices
    centroids  K-by-M matrix of cluster centroids (optional)
    y          N-by-1 vector of true class labels (optional)
    covars     K-by-M-by-M tensor of covariance matrices, one per centroid
               (optional, requires centroids)
    figsize    (width, height) passed to matplotlib's figure()

    Optional arguments are treated as "not given" when they are the string
    'None' (the default) or None.

    Raises:
    ValueError if covars is given without centroids.

    Returns nothing; the figure is created and shown as a side effect.
    '''
    import numpy as np

    X = np.asarray(X)
    cls = np.asarray(clusterid)

    have_centroids = not _is_unset(centroids)
    have_covars = not _is_unset(covars)
    if have_covars and not have_centroids:
        # The ellipses are centred on the centroids, so they cannot be drawn
        # without them.
        raise ValueError('covars requires centroids to be given as well')

    if _is_unset(y):
        # No true labels: put every object in a single class.
        y = np.zeros((X.shape[0], 1))
    else:
        y = np.asarray(y)

    n_centroids = 0
    if have_centroids:
        # atleast_2d lets a single centroid be passed as a plain vector.
        centroids = np.atleast_2d(np.asarray(centroids))
        n_centroids = centroids.shape[0]
    if have_covars:
        covars = np.asarray(covars)
        if covars.ndim == 2:
            # A single covariance matrix: treat it as a stack of one.
            covars = covars[np.newaxis]

    class_ids = np.unique(y)
    cluster_ids = np.unique(cls)
    C = class_ids.size
    K = cluster_ids.size
    # Every class, cluster and centroid needs a palette entry of its own.
    ncolors = max(C, K, n_centroids)

    # matplotlib is imported only after the arguments have been validated, so
    # that bad input fails before any plotting backend is loaded.
    from matplotlib.pyplot import figure, cm, plot, legend, xlim, show

    # Evenly spaced colors from the jet colormap; the denominator is kept at
    # least 1 so that a single color does not divide by zero.
    denom = max(ncolors - 1, 1)
    colors = [cm.jet(float(idx) / denom)[:3] for idx in range(ncolors)]

    # plot data points color-coded by class, cluster markers and centroids
    figure(figsize=figsize)
    for i, cs in enumerate(class_ids.tolist()):
        members = (y == cs).ravel()
        plot(X[members, 0], X[members, 1], 'o', markeredgecolor='k',
             markerfacecolor=colors[i], markersize=6, zorder=2,
             label='Class: {0}'.format(cs))
    for i, cr in enumerate(cluster_ids.tolist()):
        members = (cls == cr).ravel()
        plot(X[members, 0], X[members, 1], 'o', markersize=12,
             markeredgecolor=colors[i], markerfacecolor='None',
             markeredgewidth=3, zorder=1,
             label='Cluster: {0}'.format(cr))
    if have_centroids:
        # Centroid row k is colored like the k-th sorted cluster id, i.e. the
        # rows are assumed to be ordered like the cluster ids. When the counts
        # differ, the row index is used as the legend name instead.
        if n_centroids == K:
            centroid_names = cluster_ids.tolist()
        else:
            centroid_names = list(range(n_centroids))
        for cd in range(n_centroids):
            plot(centroids[cd, 0], centroids[cd, 1], '*', markersize=22,
                 markeredgecolor='k', markerfacecolor=colors[cd],
                 markeredgewidth=2, zorder=3,
                 label='Centroid: {0}'.format(centroid_names[cd]))
    # plot cluster shapes:
    if have_covars:
        for cd in range(n_centroids):
            x1, x2 = _gauss_2d(centroids[cd], covars[cd, :, :])
            # No label: the ellipses are deliberately left out of the legend.
            plot(x1, x2, '-', color=colors[cd], linewidth=3, zorder=5)

    # create legend from the labels attached to the artists above, so it
    # always matches what was actually drawn
    legend(numpoints=1, markerscale=.75, prop={'size': 9})

    # NOTE(review): these multiplicative margins are kept from the original;
    # they only widen the view when min < 0 < max -- confirm intent.
    xlim(X[:,0].min()*1.1, X[:,0].max()*1.2)

    show()