
Introduction to Data Analysis and Visualisation

Data Science | April 29, 2024

Learn the basics of data analysis and visualisation using Python with Pandas, Matplotlib and Seaborn libraries!

Recently I was tutoring a student to help them with the data science coursework in their business degree, and I thought it would be useful to share some of the lessons here. As an example, we’ll be using supermarket sales data (modified from Kaggle) and cover the basics of data preprocessing and data analysis with Python and Pandas, and data visualisation with Matplotlib and Seaborn. By the end of the post you should know how to load and filter your data and compute some basic statistics and visualisations from it.

Data Preprocessing with Pandas

Let’s start by importing the Pandas and Matplotlib libraries:

import pandas as pd
import matplotlib.pyplot as plt

Next, we import the datasets into Pandas dataframes using the Pandas read_csv() function:

df_sales = pd.read_csv('supermarket_sales_data.csv')
df_product = pd.read_csv('supermarket_product_data.csv')

Now let’s view the first few rows of each dataset using the dataframe’s head() function, to get an idea of the data in each one:

df_sales.head()
    Invoice ID Branch       City Customer type  Gender   Tax 5%     Total       Date   Time      Payment    cogs  gross margin percentage  gross income  Rating
0  750-67-8428      A     Yangon        Member  Female  26.1415  548.9715   1/5/2019  13:08      Ewallet  522.83                 4.761905       26.1415     9.1
1  226-31-3081      C  Naypyitaw        Normal  Female   3.8200   80.2200   3/8/2019  10:29         Cash   76.40                 4.761905        3.8200     9.6
2  631-41-3108      A     Yangon        Normal    Male  16.2155  340.5255   3/3/2019  13:23  Credit card  324.31                 4.761905       16.2155     7.4
3  123-19-1176      A     Yangon        Member    Male  23.2880  489.0480  1/27/2019  20:33      Ewallet  465.76                 4.761905       23.2880     8.4
4  373-73-7910      A     Yangon        Normal    Male  30.2085  634.3785   2/8/2019  10:37      Ewallet  604.17                 4.761905       30.2085     5.3
df_product.head()
    Invoice ID            Product line  Unit price  Quantity
0  750-67-8428       Health and beauty       74.69       7.0
1  226-31-3081  Electronic accessories       15.28       5.0
2  631-41-3108      Home and lifestyle       46.33       7.0
3  123-19-1176       Health and beauty       58.22       8.0
4  373-73-7910       Sports and travel       86.31       7.0

Because the two datasets contain complementary information about each transaction, let’s merge them into a single dataset on the Invoice ID using the Pandas merge() function:

df = pd.merge(df_sales, df_product, on='Invoice ID')
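
As a quick sanity check, it’s worth confirming that the merge kept every transaction; since each invoice should appear in both files, all three dataframes should have the same number of rows:

# Both inputs and the merged result should have 1000 rows each
print(len(df_sales), len(df_product), len(df))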

We can now display the columns of this merged dataset using a dataframe’s info() function, and we can see that the product and the sales columns have now been combined:

df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1000 entries, 0 to 999
Data columns (total 17 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   Invoice ID               1000 non-null   object 
 1   Branch                   1000 non-null   object 
 2   City                     1000 non-null   object 
 3   Customer type            1000 non-null   object 
 4   Gender                   1000 non-null   object 
 5   Tax 5%                   1000 non-null   float64
 6   Total                    1000 non-null   float64
 7   Date                     1000 non-null   object 
 8   Time                     1000 non-null   object 
 9   Payment                  1000 non-null   object 
 10  cogs                     1000 non-null   float64
 11  gross margin percentage  1000 non-null   float64
 12  gross income             1000 non-null   float64
 13  Rating                   1000 non-null   float64
 14  Product line             1000 non-null   object 
 15  Unit price               950 non-null    float64
 16  Quantity                 975 non-null    float64
dtypes: float64(8), object(9)
memory usage: 140.6+ KB

However, if we now look at the Non-Null Count results in the info() output, we can see that the Unit price and Quantity columns have rows with missing data (50 and 25 missing rows respectively). To handle missing data, one option is simply to delete all the rows that contain any missing data using the dropna() function:

df = df.dropna()
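
As an aside, before dropping anything we could have counted the missing values in each column directly, using the isna() and sum() functions:

# Number of missing values per column (True counts as 1 when summed)
df.isna().sum()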

If we now show the dataset info, we can see that all columns have the same number (950) of non-null rows:

df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 950 entries, 0 to 999
Data columns (total 17 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   Invoice ID               950 non-null    object 
 1   Branch                   950 non-null    object 
 2   City                     950 non-null    object 
 3   Customer type            950 non-null    object 
 4   Gender                   950 non-null    object 
 5   Tax 5%                   950 non-null    float64
 6   Total                    950 non-null    float64
 7   Date                     950 non-null    object 
 8   Time                     950 non-null    object 
 9   Payment                  950 non-null    object 
 10  cogs                     950 non-null    float64
 11  gross margin percentage  950 non-null    float64
 12  gross income             950 non-null    float64
 13  Rating                   950 non-null    float64
 14  Product line             950 non-null    object 
 15  Unit price               950 non-null    float64
 16  Quantity                 950 non-null    float64
dtypes: float64(8), object(9)
memory usage: 133.6+ KB

If we want to keep the rows containing missing data, another option is to replace the missing values with a statistic such as the column mean using the fillna() function (here applied to the merged dataset from before the dropna() step above):

df['Unit price'] = df['Unit price'].fillna(df['Unit price'].mean())
df['Quantity'] = df['Quantity'].fillna(df['Quantity'].mean())
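
As an aside, if a column is strongly skewed the mean can be pulled towards extreme values, so filling with the median is sometimes preferred (an alternative, not applied here):

# Alternative: fill missing quantities with the column median instead of the mean
df['Quantity'] = df['Quantity'].fillna(df['Quantity'].median())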

Displaying the dataset info now shows that we have kept all 1000 rows and there are no longer any missing values:

df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1000 entries, 0 to 999
Data columns (total 17 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   Invoice ID               1000 non-null   object 
 1   Branch                   1000 non-null   object 
 2   City                     1000 non-null   object 
 3   Customer type            1000 non-null   object 
 4   Gender                   1000 non-null   object 
 5   Tax 5%                   1000 non-null   float64
 6   Total                    1000 non-null   float64
 7   Date                     1000 non-null   object 
 8   Time                     1000 non-null   object 
 9   Payment                  1000 non-null   object 
 10  cogs                     1000 non-null   float64
 11  gross margin percentage  1000 non-null   float64
 12  gross income             1000 non-null   float64
 13  Rating                   1000 non-null   float64
 14  Product line             1000 non-null   object 
 15  Unit price               1000 non-null   float64
 16  Quantity                 1000 non-null   float64
dtypes: float64(8), object(9)
memory usage: 140.6+ KB

Next, the dataset may contain outliers: data points that lie far from the bulk of the distribution (for example, arising from mistakes when manually logging the data). To visualise these we can use a box plot of the relevant column, in this case Quantity:

plt.boxplot(df['Quantity'])
plt.title('Box Plot of Quantity')
plt.ylabel('Quantity')
plt.grid(True)
plt.show()

A box plot visualises the data’s median (the orange line), the 25th and 75th percentiles (the edges of the box) and whiskers that extend to the furthest data points within 1.5 times the interquartile range (IQR) of the box edges. Any data points outside the whiskers are drawn as circles and denote the outliers; in this case there are some outliers around Quantity=20. We can get the values of the outliers using the loc indexer, selecting all rows of the Quantity column that fall outside the 1.5 × IQR range:

Q1 = df['Quantity'].quantile(0.25)
Q3 = df['Quantity'].quantile(0.75)
IQR = Q3 - Q1
outliers = ((df['Quantity'] < (Q1 - 1.5 * IQR)) | (df['Quantity'] > (Q3 + 1.5 * IQR)))
df.loc[outliers, 'Quantity']
93     20.0
100    20.0
562    20.0
692    20.0
729    20.0
Name: Quantity, dtype: float64

In this case we have 5 outliers, all with the value 20. We again have a few options for handling outliers; one is simply to delete all the corresponding rows:

df = df[~outliers]

Another option, instead of deleting the rows, is to cap the outliers at some value such as the 99th percentile:

df.loc[outliers, 'Quantity'] = df['Quantity'].quantile(0.99)
df.loc[outliers, 'Quantity']
93     10.0
100    10.0
562    10.0
692    10.0
729    10.0
Name: Quantity, dtype: float64
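
The same capping can also be done in one step with Pandas’ clip() function, which bounds every value at the given limits:

# Equivalent to the loc assignment above: cap all quantities at the 99th percentile
df['Quantity'] = df['Quantity'].clip(upper=df['Quantity'].quantile(0.99))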

Data Analysis with Pandas

Now that we have preprocessed the data by merging the datasets and handling missing data and outliers, we are ready to do some data analysis. We can get basic statistics about the numeric columns in the dataset by using the describe() function:

df.describe()
           Tax 5%        Total        cogs  gross margin percentage  gross income      Rating  Unit price    Quantity
count  950.000000   950.000000  950.000000             9.500000e+02    950.000000  950.000000  950.000000  950.000000
mean    15.416576   323.748103  308.331526             4.761905e+00     15.416576    6.986421   56.130378    5.526316
std     11.769669   247.163048  235.393379             6.042795e-14     11.769669    1.711092   27.268558    2.926720
min      0.508500    10.678500   10.170000             4.761905e+00      0.508500    4.000000   10.080000    1.000000
25%      5.789250   121.574250  115.785000             4.761905e+00      5.789250    5.600000   32.642500    3.000000
50%     12.218750   256.593750  244.375000             4.761905e+00     12.218750    7.000000   55.620000    5.000000
75%     22.687375   476.434875  453.747500             4.761905e+00     22.687375    8.500000   78.380000    8.000000
max     49.650000  1042.650000  993.000000             4.761905e+00     49.650000   10.000000  135.649819   10.000000

We can also get specific statistics like the mean, median and standard deviation for particular columns as:

print(df['Quantity'].mean())
print(df['Unit price'].median())
print(df['Total'].std())
5.526315789473684
55.620000000000005
247.1630476392563

To compare different subgroups in the data, such as the mean values of each column for member and nonmember customers, we can use the groupby() function:

df.groupby('Customer type').mean(numeric_only=True)
                  Tax 5%       Total        cogs  gross margin percentage  gross income    Rating  Unit price  Quantity
Customer type
Member         15.744493  330.634348  314.889855                 4.761905     15.744493  6.955187   56.986513  5.603734
Normal         15.078850  316.655859  301.577009                 4.761905     15.078850  7.018590   55.248632  5.446581

Or if we want to compare the sales across different supermarket branches, we can group the Total column by Branch and take the sum:

df.groupby('Branch')['Total'].sum()
Branch
A     99478.3125
B    102748.3275
C    105334.0575
Name: Total, dtype: float64

If we now want to count how many of each value a particular column has, such as how many sales there have been of each different product line, we can do this with the value_counts() function:

df['Product line'].value_counts()
Fashion accessories       170
Food and beverages        169
Electronic accessories    159
Sports and travel         157
Home and lifestyle        152
Health and beauty         143
Name: Product line, dtype: int64
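
If proportions are more useful than raw counts, value_counts() also accepts a normalize parameter:

# Fraction of sales per product line (the values sum to 1)
df['Product line'].value_counts(normalize=True)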

To determine how different variables are related to each other, we can compute the correlation matrix for the numeric columns (note that the gross margin percentage column is constant, so its correlations are undefined and appear as NaN):

df.corr(numeric_only=True)
                           Tax 5%     Total      cogs  gross margin percentage  gross income    Rating  Unit price  Quantity
Tax 5%                   1.000000  1.000000  1.000000                      NaN      1.000000 -0.038968    0.631543  0.697840
Total                    1.000000  1.000000  1.000000                      NaN      1.000000 -0.038968    0.631543  0.697840
cogs                     1.000000  1.000000  1.000000                      NaN      1.000000 -0.038968    0.631543  0.697840
gross margin percentage       NaN       NaN       NaN                      NaN           NaN       NaN         NaN       NaN
gross income             1.000000  1.000000  1.000000                      NaN      1.000000 -0.038968    0.631543  0.697840
Rating                  -0.038968 -0.038968 -0.038968                      NaN     -0.038968  1.000000   -0.006769 -0.023632
Unit price               0.631543  0.631543  0.631543                      NaN      0.631543 -0.006769    1.000000  0.042173
Quantity                 0.697840  0.697840  0.697840                      NaN      0.697840 -0.023632    0.042173  1.000000
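
If we only care about how each variable relates to a single target such as Total, we can select that column of the correlation matrix and sort it:

# Correlations of every numeric column with Total, strongest first
df.corr(numeric_only=True)['Total'].sort_values(ascending=False)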

Data Visualisation with Matplotlib

Let’s now visualise our data using the Matplotlib library. Matplotlib can generate different types of plots, such as line plots for visualising continuous data, scatter plots for investigating the relationship between variables, bar charts for categorical variables, and pie charts for visualising proportions of data. Let’s say we first want to visualise the total sales as a function of time. We can do this by first converting the Date column of the dataset from a string into a Pandas datetime format and creating a new Month column containing the month of each sale. We can then display the sum total of the sales by month using the groupby() function as:

df['Date'] = pd.to_datetime(df['Date'], format='%m/%d/%Y')  # dates are in month/day/year format
df['Month'] = df['Date'].dt.month
monthly_sales = df.groupby('Month')['Total'].sum()
monthly_sales
Month
1    111017.1195
2     92414.6580
3    104128.9200
Name: Total, dtype: float64

To visualise this, we first create a Matplotlib plot using the figure() function and then call the plot() function on the dataset column we want to visualise, in this case monthly_sales. We can then set the plot title, axis labels and show the grid. We can see that the sales drop slightly in February but recover by March:

plt.figure(figsize=(5, 3))
monthly_sales.plot(kind='line')
plt.title('Monthly Sales Trends')
plt.xlabel('Month')
plt.ylabel('Total Sales')
plt.grid(True)
plt.show()

Let’s say we now want to compare the sum total of sales across different product lines. We can do that by first grouping the sum of the Total column by the product line, and then sorting the sales as:

sales_by_product_line = df.groupby('Product line')['Total'].sum().sort_values()
sales_by_product_line
Product line
Health and beauty         46398.5445
Electronic accessories    50201.9595
Home and lifestyle        50918.9520
Fashion accessories       52140.1020
Sports and travel         52874.6610
Food and beverages        55026.4785
Name: Total, dtype: float64

We can now plot a bar chart visualising the sales for each category by calling the plot function on the previous data, now with the parameter ‘kind’ set to ‘bar’. We can rotate the x axis labels slightly and then compress the plot using the tight_layout() function:

plt.figure(figsize=(5, 3))
sales_by_product_line.plot(kind='bar')
plt.title('Sales by Product Line')
plt.xlabel('Product Line')
plt.ylabel('Total Sales')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

Let’s say we now wish to investigate whether higher priced products generate more total sales. We can do that with a scatter plot showing the total sales as a function of the unit price, using Matplotlib’s scatter() function. In this case we can see that higher unit prices are associated with higher total sales, along with a few outlying points away from the general trend:

plt.figure(figsize=(5, 3))
plt.scatter(df['Unit price'], df['Total'], alpha=0.5, color='green')
plt.title('Relationship between Total Sales and Unit Price')
plt.xlabel('Unit Price')
plt.ylabel('Total Sales')
plt.grid(True)
plt.show()

Finally, let’s say we are interested in comparing the sales for member and nonmember customers, to evaluate the effectiveness of membership. We can do this by plotting a pie chart of the total sales grouped by the customer type as follows, and we can see that members account for a somewhat larger share of total sales than nonmembers:

customer_sales = df.groupby('Customer type')['Total'].sum()
plt.figure(figsize=(5, 3))
customer_sales.plot(kind='pie')
plt.title('Proportion of Sales by Customer Type')
plt.tight_layout()
plt.show()
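
As a small optional tweak, passing the autopct parameter labels each slice with its percentage:

# The same pie chart, with a percentage label on each slice
customer_sales.plot(kind='pie', autopct='%1.1f%%')
plt.title('Proportion of Sales by Customer Type')
plt.show()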

Advanced Visualisations with Seaborn

We can create more advanced and better-looking plots with the Seaborn library, which builds on top of Matplotlib. For example, we can create violin plots showing the density of a variable, pair plots showing the relationships between multiple variables, and heatmaps for visualising correlations. First, let’s create a violin plot that visualises the distributions of a variable, in this case comparing the distributions of unit price across different branches (note that we also need to import Seaborn here):

import seaborn as sns

plt.figure(figsize=(5, 3))
sns.violinplot(x='Branch', y='Unit price', data=df)
plt.title('Distribution of Unit Prices by Branch')
plt.xlabel('Branch')
plt.ylabel('Unit Price ($)')
plt.show()

Next, if we want to investigate the relationships between several variables at once, we can use a pair plot, in this case showing how unit price, quantity and total are related across different branches. We can see that higher unit prices and quantities are associated with higher total sales.

sns.pairplot(df, hue='Branch', vars=['Unit price', 'Quantity', 'Total'])
plt.suptitle('Pairwise Relationships by Branch', y=1.02)
plt.show()

We can visualise 2-dimensional data using a heatmap plot, in this case visualising the correlation matrix between unit price, quantity and total:

corr = df[['Unit price', 'Quantity', 'Total']].corr()
plt.figure(figsize=(5, 4))
sns.heatmap(corr, annot=True, fmt=".2f", cmap='coolwarm', cbar=True)
plt.title('Correlation Matrix of Prices, Quantities, and Totals')
plt.show()
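
Since a correlation matrix is symmetric, a common optional refinement is to mask the upper triangle so that each correlation is shown only once:

import numpy as np

# Boolean mask hiding the upper triangle (np.triu with the default k=0 also hides the diagonal)
mask = np.triu(np.ones_like(corr, dtype=bool))
plt.figure(figsize=(5, 4))
sns.heatmap(corr, mask=mask, annot=True, fmt=".2f", cmap='coolwarm')
plt.title('Correlation Matrix (lower triangle only)')
plt.show()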

Finally, we can use a facet grid to draw the same plot for every combination of two categorical variables, one across the columns and one across the rows of a grid of subplots. For example, here we visualise the histogram of total sales for each combination of branch and payment method:

g = sns.FacetGrid(df, col="Branch", row="Payment", margin_titles=True)
g.map(plt.hist, "Total", color="steelblue", bins=30)
g.set_axis_labels('Total', 'Frequency')
g.fig.suptitle('Total Sales Distribution by Branch and Payment Method', y=1.03)
plt.show()

So there you have it, an introduction to data preprocessing, data analysis and data visualisation for business. There are a lot more topics I could cover, such as more advanced data filtering and analysis and more complex plots with Seaborn, but as this is an introduction I wanted to keep it short and self-contained. I hope this gives you an idea of how to get started with doing basic data science in business!

Need help with a project or coursework? Check out my tutoring services here.
