Wine Classification Using Linear Discriminant Analysis with Python and SciKit-Learn

In this post, a classifier is constructed that determines which cultivar a given wine sample comes from. Each sample consists of a vector \textbf{v} of 13 attributes of the wine, that is, \textbf{v} \in \mathbb{R}^{13}. The attributes are as follows:

  1. Alcohol
  2. Malic acid
  3. Ash
  4. Alcalinity of ash
  5. Magnesium
  6. Total phenols
  7. Flavanoids
  8. Nonflavanoid phenols
  9. Proanthocyanins
  10. Color intensity
  11. Hue
  12. OD280/OD315 of diluted wines
  13. Proline

Based on these attributes, the goal is to identify which of the three cultivars a sample originated from. The data set is available from the UCI Machine Learning Repository. Three sample rows from the data set are shown below.


The first column denotes the target class. The data can be read into a matrix using the loadtxt function from NumPy.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# %% Read data from csv file
A = np.loadtxt('', delimiter=',')  # path to the comma-separated UCI wine data file
# Get the targets (first column of file)
y = A[:, 0]
# Remove targets from input data
A = A[:, 1:]

Linear Discriminant Analysis

The purpose of linear discriminant analysis (LDA) is to estimate the probability that a sample belongs to a specific class given the data sample itself, that is, to estimate Pr(C=c_{i} | X=x), where C is the class label, which takes values in the set \{c_{1}, c_{2}, \ldots, c_{m}\} of class identifiers, X is the input vector (a random variable over the domain), and x is the specific sample. Applying Bayes' theorem gives:

Pr(C=c_{i} | X=x)=\frac{Pr(X=x | C=c_{i})Pr(C=c_{i})}{\sum_{j=1}^{m}{Pr(X=x | C=c_{j})Pr(C=c_{j})}}.
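As a small numerical illustration of the formula above, the following sketch evaluates the posterior for one sample from per-class likelihoods and priors. The function name `posterior` and the input numbers are hypothetical, chosen only to show how the denominator normalizes the joint terms so that the posteriors sum to one.

```python
import numpy as np

def posterior(likelihoods, priors):
    """Return Pr(C=c_i | X=x) for every class i via Bayes' theorem.

    likelihoods: Pr(X=x | C=c_j) for j = 1..m (hypothetical values below)
    priors:      Pr(C=c_j) for j = 1..m
    """
    joint = np.asarray(likelihoods, dtype=float) * np.asarray(priors, dtype=float)
    # The denominator sums the joint terms over all m classes
    return joint / joint.sum()

# Hypothetical values for m = 3 classes
post = posterior([0.02, 0.10, 0.01], [0.3, 0.4, 0.3])
print(post.round(3))
```

The predicted class is simply the one with the largest posterior, here the second class.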

Pr(C=c_{i}) can be estimated by the frequency of class c_{i} in the training data. LDA assumes that each class can be modeled as a multivariate Gaussian distribution with each class sharing a common covariance matrix \boldsymbol{\Sigma}. That is:

Pr(X=x | C=c_{i})=\frac{1}{(2\pi)^{p/2}|\boldsymbol{\Sigma}|^{1/2}}e^{-\frac{1}{2}(x-\boldsymbol{\mu}_{c_{i}})^{T}\boldsymbol{\Sigma}^{-1}(x-\boldsymbol{\mu}_{c_{i}})},

where p is the number of attributes (here p = 13), \boldsymbol{\mu}_{c_{i}} is the mean vector of class c_{i}, and \boldsymbol{\Sigma} is the covariance matrix shared by all classes. The shared covariance matrix and the mean vectors are estimated from the training data.
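The estimation step described above can be sketched from scratch in NumPy. This is a minimal illustration under stated assumptions, not scikit-learn's implementation: the helper names (`fit_lda`, `log_gaussian`, `predict_proba`) are hypothetical, the shared covariance is taken as the pooled within-class estimate divided by n - m, and the data are synthetic rather than the wine data. The posterior is computed in log space, since multiplying Gaussian densities directly in 13 dimensions can underflow.

```python
import numpy as np

def fit_lda(X, y):
    """Estimate class priors, class means and the shared covariance matrix."""
    classes = np.unique(y)
    n = X.shape[0]
    # Pr(C=c_i) estimated by the frequency of class c_i
    priors = np.array([np.mean(y == c) for c in classes])
    means = np.array([X[y == c].mean(axis=0) for c in classes])
    # Pooled within-class covariance, shared by all classes (assumed n - m divisor)
    scatter = sum((X[y == c] - mu).T @ (X[y == c] - mu)
                  for c, mu in zip(classes, means))
    sigma = scatter / (n - len(classes))
    return classes, priors, means, sigma

def log_gaussian(X, mu, sigma):
    """Log of the multivariate normal density N(mu, sigma) at each row of X."""
    p = X.shape[1]
    diff = X - mu
    # Mahalanobis terms (x - mu)^T sigma^{-1} (x - mu), without an explicit inverse
    maha = np.sum(diff * np.linalg.solve(sigma, diff.T).T, axis=1)
    _, logdet = np.linalg.slogdet(sigma)
    return -0.5 * (p * np.log(2 * np.pi) + logdet + maha)

def predict_proba(X, priors, means, sigma):
    """Posterior Pr(C=c_i | X=x) for each row of X, computed in log space."""
    log_joint = np.column_stack([log_gaussian(X, mu, sigma) + np.log(prior)
                                 for mu, prior in zip(means, priors)])
    # Subtract the row maximum before exponentiating to avoid underflow
    log_joint -= log_joint.max(axis=1, keepdims=True)
    joint = np.exp(log_joint)
    return joint / joint.sum(axis=1, keepdims=True)

# Synthetic stand-in data: three 2-D Gaussian classes with a common covariance
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
X = np.vstack([rng.normal(c, 1.0, size=(50, 2)) for c in centers])
y = np.repeat([1, 2, 3], 50)

classes, priors, means, sigma = fit_lda(X, y)
proba = predict_proba(X, priors, means, sigma)
pred = classes[proba.argmax(axis=1)]
print("training accuracy:", np.mean(pred == y))
```

Subtracting the row-wise maximum leaves the posterior unchanged, because the same constant factor cancels between the numerator and the denominator of Bayes' theorem.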

LDA Dimensionality Reduction

The centroids of the m classes lie within an affine subspace of dimension at most m-1. The input data can therefore be projected onto a lower-dimensional subspace that is optimal in terms of LDA classification. An optimal subspace is defined as one in which the between-class variance is maximized relative to the within-class variance; that is, the class centroids are spread as far apart as possible relative to the spread of the samples around their own centroids. For a single direction \textbf{a}, this is expressed as maximizing the Rayleigh quotient:

\max_{\textbf{a}} \frac{\textbf{a}^{T}\textbf{B}\textbf{a}}{\textbf{a}^{T}\textbf{W}\textbf{a}},

where

\textbf{W}=\sum\limits_{i=1}^{m}{\sum\limits_{j \in c_{i}}{(\textbf{x}_{j}-\boldsymbol{\mu}_{c_i})(\textbf{x}_{j}-\boldsymbol{\mu}_{c_i})^{T}}},

is the within-class scatter matrix,

\textbf{B}=\sum\limits_{i=1}^{m}{N_{i}(\boldsymbol{\mu}_{c_i}-\overline{\textbf{x}})(\boldsymbol{\mu}_{c_i}-\overline{\textbf{x}})^{T}},

is the between-class scatter matrix, N_{i} is the number of samples belonging to class c_{i}, and \overline{\textbf{x}} is the mean vector of all input vectors. Maximizing this quotient is equivalent to solving the generalized eigenvalue problem \textbf{B}\textbf{a}=\lambda\textbf{W}\textbf{a}, whose solution is given by the largest eigenvalue of the matrix \textbf{W}^{-1}\textbf{B}, the corresponding eigenvector being the solution vector \textbf{a}. Further discriminant directions are the eigenvectors belonging to the next largest eigenvalues. This computation, along with the dimension reduction, can be performed using scikit-learn as follows:

# Fit LDA on all samples and keep the 2 discriminant components
lda = LinearDiscriminantAnalysis(n_components=2).fit(A, y)
drA = lda.transform(A)

As there are m=3 classes in this example, the data is transformed from \mathbb{R}^{13} to \mathbb{R}^{2} by preserving 2 components corresponding to the 2 largest eigenvalues of \textbf{W}^{-1}\textbf{B}. A plot of the transformed data is shown next, with classes denoted with different colors.

Figure 1: Transformed Data Plot
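The scatter matrices and the eigenvalue problem from the dimensionality-reduction discussion can be checked numerically with a short NumPy sketch. The helper `lda_directions` and the synthetic 4-dimensional, 3-class data are assumptions for illustration, and W^{-1}B is formed with `np.linalg.solve` rather than an explicit inverse. With m = 3 classes, B has rank at most 2, so only two eigenvalues should be clearly non-zero, and the Rayleigh quotient evaluated at the leading eigenvector should equal the largest eigenvalue.

```python
import numpy as np

def lda_directions(X, y):
    """Within-class scatter W, between-class scatter B, and the eigenvalues and
    eigenvectors of W^{-1} B sorted by decreasing eigenvalue."""
    p = X.shape[1]
    xbar = X.mean(axis=0)                  # mean of all input vectors
    W = np.zeros((p, p))
    B = np.zeros((p, p))
    for c in np.unique(y):
        Xc = X[y == c]
        mu = Xc.mean(axis=0)               # class centroid mu_{c_i}
        W += (Xc - mu).T @ (Xc - mu)       # within-class scatter
        d = (mu - xbar)[:, None]
        B += len(Xc) * (d @ d.T)           # between-class scatter, weighted by N_i
    # Eigenvectors of W^{-1} B solve the generalized problem B a = lambda W a
    evals, evecs = np.linalg.eig(np.linalg.solve(W, B))
    evals, evecs = evals.real, evecs.real  # spectrum is real; drop round-off imaginary parts
    order = np.argsort(evals)[::-1]
    return W, B, evals[order], evecs[:, order]

# Synthetic stand-in data: 3 classes in R^4
rng = np.random.default_rng(1)
centers = np.array([[0, 0, 0, 0], [3, 0, 1, 0], [0, 3, 0, 1]], dtype=float)
X = np.vstack([rng.normal(c, 1.0, size=(40, 4)) for c in centers])
y = np.repeat([1, 2, 3], 40)

W, B, evals, evecs = lda_directions(X, y)
print(np.round(evals, 6))              # only m - 1 = 2 eigenvalues are clearly non-zero

a = evecs[:, 0]
rayleigh = (a @ B @ a) / (a @ W @ a)   # equals the largest eigenvalue
Z = X @ evecs[:, :2]                   # reduce the data from R^4 to R^2
print(rayleigh, Z.shape)
```

The sketch does no centering or rescaling of the directions, so its projected coordinates need not match `lda.transform` coordinate for coordinate; only the spanned subspace is the point of the check.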

LDA Classification

A transformed data point can be classified by identifying the class centroid \boldsymbol{\mu}_{c_i} to which it is closest in the transformed space (when the class priors differ, each distance is additionally offset by a log-prior term). The class centroids are shown below (as large black points) along with the transformed data, plotted on the Voronoi diagram induced by the centroids. The Voronoi cells correspond to the regions that LDA will classify as belonging to the respective centroid’s class.

Figure 2: Transformed Data Plot with Centroids and Voronoi Cells

As can be seen, there is clear separation between the three classes of wine in this case and so the classifier is expected to perform very well.
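The nearest-centroid rule described above can be sketched in a few lines. The function `nearest_centroid`, the centroids, and the points are hypothetical. Without priors, each point goes to its closest centroid, so the decision regions are the Voronoi cells; the optional `priors` argument applies the standard LDA prior adjustment, which assumes coordinates in which the shared covariance is the identity.

```python
import numpy as np

def nearest_centroid(Z, centroids, priors=None):
    """Assign each row of Z to a centroid index.

    Without priors this is plain Euclidean nearest-centroid assignment, whose
    decision regions are the Voronoi cells of the centroids. With priors, the
    score 0.5 * ||z - mu_k||^2 - log(pi_k) is minimised, which is the LDA rule
    when the shared covariance is the identity in these coordinates.
    """
    # Squared distances from every point to every centroid: (n_points, n_centroids)
    score = ((Z[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    if priors is not None:
        score = 0.5 * score - np.log(priors)
    return score.argmin(axis=1)

# Hypothetical centroids and points in a 2-D discriminant space
centroids = np.array([[-3.0, 0.0], [3.0, 0.0], [0.0, 4.0]])
Z = np.array([[-2.5, 0.5], [2.0, -1.0], [0.2, 3.1], [0.0, 1.0]])
print(nearest_centroid(Z, centroids))                   # → [0 1 2 2]
print(nearest_centroid(Z, centroids, [0.1, 0.8, 0.1]))  # → [0 1 2 1]
```

The last point sits almost equidistant from all three centroids, so a strong prior on the second class is enough to move it across the Voronoi boundary.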


Cross validation is used to test the performance of the classifier. In k-fold cross validation, the input data set \textbf{V} is partitioned into k disjoint folds. In each of the k rounds, one fold serves as the test set T_{2} and the remaining folds form the training set T_{1}, so that T_{1} \cap T_{2} = \emptyset and T_{1} \cup T_{2} = \textbf{V}. A larger share of the data is thus allocated for training (two thirds when k = 3), and the classifier is trained and scored once per round. This can be accomplished in Python using scikit-learn as follows:

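# %% Data extracted; perform LDA
from sklearn.model_selection import KFold
lda = LinearDiscriminantAnalysis()
# Split the samples into 3 shuffled folds
k_fold = KFold(n_splits=3, shuffle=True)
print('LDA Results: ')
for trn, tst in k_fold.split(A):
    # Train on the other two folds, then score on the held-out fold
    lda.fit(A[trn], y[trn])
    # score() returns the mean classification accuracy on the test fold
    outVal = lda.score(A[tst], y[tst])
    print('Score: ' + str(outVal))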

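The scores (mean accuracy on each held-out fold) from the three folds are as follows; because the folds are shuffled, the exact values vary between executions: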

  • Run 1: 1.0
  • Run 2: 0.983050847458
  • Run 3: 0.966101694915

As can be seen, the classifier predicts the correct cultivar for a wine sample with high accuracy, owing to the well-behaved structure of the classes.
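The k-fold partition described in the cross-validation paragraph can be sketched without scikit-learn. The helper `kfold_indices` is hypothetical and only mirrors the idea behind `KFold(n_splits=3, shuffle=True)`; scikit-learn's exact fold assignment may differ. In every round the training set T_1 and the test set T_2 are disjoint and together cover all samples.

```python
import numpy as np

def kfold_indices(n, k, seed=None):
    """Yield (train, test) index arrays for k-fold cross validation."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)            # shuffle the sample indices once
    folds = np.array_split(perm, k)      # k disjoint, nearly equal folds
    for i in range(k):
        test = folds[i]                  # T_2: the held-out fold
        train = np.concatenate([folds[j] for j in range(k) if j != i])  # T_1
        yield train, test

splits = list(kfold_indices(10, 3, seed=0))
for train, test in splits:
    print(len(train), len(test))         # prints 6 4, then 7 3 twice
```

Every sample appears in exactly one test fold, so each sample is scored exactly once across the k rounds.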
