Matplotlib

Matplotlib is a Python library for creating static, interactive, and animated visualizations. It provides a wide range of plotting functions for line plots, bar charts, histograms, scatter plots, and more.

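Every example below assumes the statements `import matplotlib.pyplot as plt` and `import numpy as np` at the top of the script. The examples target a recent Matplotlib (3.9 or newer), because the box-plot example uses the `tick_labels=` argument introduced in that version.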

Table of contents

Introduction

How to create a chart?
  1. Create the figure
  2. Plot the data (multiple times, if needed)
  3. Configure the axes
  4. Add annotations
  5. Show/save the figure

plt.figure() # creating the graph figure

# Plotting two lines through the given points: the first list holds the X coordinates,
# the second the Y coordinates; label is the text shown for the line in the legend
plt.plot([1, 2], [3, 4], label = "Python")
plt.plot([1, 3], [1, 2], label = "Java")

# Setting the limits (boundaries) of both axes
plt.xlim(0, 4)
plt.ylim(0, 5)

# Positions on the X axis where ticks should appear, along with their optional labels
plt.xticks([0, 1, 2, 3, 4], ["zero", "one", "two", "three", "four"])
plt.yticks([0, 1, 2, 3, 4, 5])

# Labeling the axes
plt.xlabel("X")
plt.ylabel("Y")

plt.legend(loc = "upper center", ncol = 2) # legend at the top centre, one column per plotted line
plt.grid(True) # enabling the default major-grid lines

# To save the graph as an image, call savefig() BEFORE show(): show() may close the
# figure, and saving afterwards can write an empty image.
# plt.savefig("graph.png")
plt.show() # displaying the graph
                                    

Custom styles

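We can apply built-in or custom styles by indenting the whole code from the example above inside a `with` block: `with plt.style.context("ggplot"):`. The style is active only inside that block. `ggplot` is one of the built-in styles; `plt.style.available` lists all of them. You can also define your own style in a file and use it the same way, e.g. `with plt.style.context("custom_style.mplstyle"):`. To apply a style to every following plot instead, call `plt.style.use(...)`.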


# the custom_style.mplstyle file
# Each line sets one Matplotlib rcParams entry as "key: value";
# lines starting with "#" are comments.

# Colors
# Colors used, in order, for successive series plotted on the same axes
axes.prop_cycle: cycler("color", ["#1b9e77", "#d95f02", "#7570b3", "#e7298a"])

# Figure
# Default figure size (width, height in inches) and resolution (dots per inch)
figure.figsize: 7, 4
figure.dpi: 120

# Axes
# Font sizes of the axes title and the X/Y axis labels, grid shown by default,
# and the top and right frame lines (spines) hidden
axes.titlesize: 14
axes.labelsize: 12
axes.grid: True
axes.spines.top: False
axes.spines.right: False

# Grid
# Solid gray grid lines, 0.8 points wide, 60% opaque
grid.color: gray
grid.linestyle: -
grid.linewidth: 0.8
grid.alpha: 0.6

# Ticks
# Font size of the tick labels on both axes
xtick.labelsize: 10
ytick.labelsize: 10

# Legend
# Legend drawn without a surrounding frame
legend.frameon: False
legend.fontsize: 10

# Lines
# Default line width and marker size (in points)
lines.linewidth: 2
lines.markersize: 6
                                    

Line plots

Line plots are used to visualize trends and continuous data over time or ordered sequences.


# Month numbers 1..12 and the sales recorded in each month
months = np.arange(1, 13)
sales = np.array([1200, 1350, 1500, 1600, 1800, 2000, 2200, 2100, 1900, 1700, 1600, 1800])

plt.figure(figsize = (10, 6))
plt.plot(
    months,
    sales,
    label = "Sales",
    color = "blue",
    marker = "o", # marker: the symbol drawn at each data point
    markersize = 8,
    linestyle = "-", # solid line connecting the points
    linewidth = 2,
)
plt.xlabel("Month")
plt.ylabel("Sales ($)")
plt.title("Monthly sales over a year")
plt.xticks(months) # one tick for every month
plt.grid(True)
plt.legend()
plt.show()
                                    
Chart

Bar charts

Bar charts are used to compare categorical data by showing values as rectangular bars.


years = np.arange(2020, 2023) # 2020, 2021, 2022
values1 = [1, 4, 8]
values2 = [2, 5, 9]
bar_width = 0.4

plt.figure()
# bar() draws vertical bars, barh() horizontal bars.
# Shifting each series by half a bar width puts the two bars side by side around each year.
plt.bar(years - bar_width / 2, values1, color = "blue", edgecolor = "none", width = bar_width, align = "center", label = "y1")
plt.bar(years + bar_width / 2, values2, color = "green", edgecolor = "none", width = bar_width, align = "center", label = "y2")
plt.xticks(years, [str(year) for year in years]) # one tick per year, labelled with the year

plt.legend()
plt.show()
                                    
Chart

Histograms

Histograms are used to display the distribution of numeric data by grouping values into bins.


np.random.seed(0) # fixed seed so the generated scores are reproducible
# 50 exam scores drawn from a normal distribution with mean 75 and standard deviation 10
scores = np.random.normal(loc = 75, scale = 10, size = 50)

plt.figure(figsize = (8, 6))
plt.hist(
    scores,
    bins = 10, # the number of intervals (bars) the data is grouped into
    color = "skyblue",
    edgecolor = "black",
    alpha = 0.7, # slightly transparent bars
)
plt.xlabel("Score")
plt.ylabel("Number of students")
plt.title("Distribution of exam scores")
plt.grid(axis = "y", alpha = 0.75) # horizontal grid lines only
plt.show()
                                    
Chart

Pie charts

Pie charts are used to show each category's share of a whole, as proportions or percentages.


counts = [13, 15, 20]

plt.figure()
plt.pie(
    counts,
    labels = ["A", "B", "C"],
    colors = ["green", "blue", "red"],
    autopct = "%1.1f%%", # each slice's share with one decimal place, e.g. "27.1%"
    startangle = 90, # rotate the start of the first slice by 90 degrees
)
plt.show()
                                    
Chart

Box charts

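Box charts (box plots) are used to summarize data distributions, highlighting medians, quartiles, and outliers. The box spans the first to the third quartile, and the line inside it marks the median. By default, the whiskers extend to the most extreme data points within 1.5 × the interquartile range of the box, and points beyond them are drawn individually as outliers.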


np.random.seed(0) # fixed seed so both classes' scores are reproducible

# np.random.normal(mean, standard deviation, number of samples)
class_a = np.random.normal(75, 10, 30) # class A: average 75, spread 10, 30 students
class_b = np.random.normal(80, 12, 30) # class B: average 80, spread 12, 30 students

plt.figure(figsize = (8, 6)) # width and height of the figure in inches

plt.boxplot(
    [class_a, class_b], # one box per dataset
    tick_labels = ["Class A", "Class B"], # labels on the x-axis (Matplotlib 3.9+; older versions use labels=)
    patch_artist = True, # draw the boxes as filled patches so facecolor applies
    boxprops = {"facecolor": "lightblue", "color": "blue"}, # box fill and border colors
    medianprops = {"color": "red"}, # median line
    whiskerprops = {"color": "blue"}, # whiskers
    capprops = {"color": "blue"}, # whisker caps
)

plt.title("Exam score distribution by class")
plt.ylabel("Scores")
plt.grid(axis = "y", alpha = 0.75) # semi-transparent horizontal grid lines only

plt.show()
                                    
Chart

Scatter plots

Scatter plots are used to examine relationships or correlations between two numeric variables.


study_hours = np.arange(1, 11) # 1, 2, ..., 10 hours
exam_scores = np.array([50, 55, 60, 65, 70, 72, 78, 85, 88, 95])

plt.figure(figsize = (8, 6))
# s is the marker size; black edges keep overlapping points distinguishable
plt.scatter(study_hours, exam_scores, s = 80, color = "blue", edgecolors = "black", alpha = 0.7)
plt.xlabel("Study hours")
plt.ylabel("Exam scores")
plt.title("Relationship between study hours and exam scores")
plt.grid(True)

# Least-squares straight line (degree-1 polynomial) through the points
coeffs = np.polyfit(study_hours, exam_scores, 1)
trend = np.polyval(coeffs, study_hours) # the fitted line evaluated at each x
plt.plot(study_hours, trend, "r--", label="Trend line") # "r--" is a red dashed line
plt.legend()

plt.show()
                                    
Chart

Subplots

Subplots are used to display multiple plots in a single figure for comparison or combined visualization.


def _break_at_asymptotes(x, y, max_jump):
    """Return copies of x and y with NaN inserted where the curve jumps across an asymptote.

    A jump is a pair of neighbouring samples whose y values have opposite signs and
    differ by more than max_jump. Matplotlib does not draw line segments through NaN,
    so this removes the near-vertical line that would otherwise join a large positive
    value to a large negative one (e.g. tan(x) just before and just after pi/2).
    """
    x = np.asarray(x, dtype = float)
    y = np.asarray(y, dtype = float)
    jumps = (np.sign(y[:-1]) != np.sign(y[1:])) & (np.abs(np.diff(y)) > max_jump)
    split_at = np.flatnonzero(jumps) + 1 # insert the NaN right before the sample following each jump
    return np.insert(x, split_at, np.nan), np.insert(y, split_at, np.nan)


def draw_function(subplot_data, x, y, color, label, title, ylim = None):
    """Draw y(x) into one cell of a subplot grid.

    subplot_data -- (nrows, ncols, index) passed to plt.subplot, e.g. [2, 2, 1]
    x, y         -- sample points and the function values at those points
    color        -- line color
    label        -- legend entry for the line
    title        -- subplot title
    ylim         -- optional (bottom, top) limits of the Y axis; when given, the line
                    is also broken where it jumps across a vertical asymptote
    """
    plt.subplot(*subplot_data) # subplot(nrows, ncols, index)
    if ylim is not None:
        x, y = _break_at_asymptotes(x, y, abs(ylim[1] - ylim[0]))
    plt.plot(x, y, color = color, label = label)
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim) # limit the Y axis for better visibility
    plt.legend()
    plt.grid(True)


# 100 evenly spaced values from 0.1 to 10 (starting at 0.1 avoids dividing by tan(0) = 0 for cot)
x = np.linspace(0.1, 10, 100)

plt.figure(figsize = (10, 8)) # a 2 x 2 grid needs more room than the default figure size

draw_function([2, 2, 1], x, np.sin(x), "blue", "sin(x)", "Sine")
draw_function([2, 2, 2], x, np.cos(x), "red", "cos(x)", "Cosine")
draw_function([2, 2, 3], x, np.tan(x), "green", "tan(x)", "Tangent", ylim = (-10, 10))
draw_function([2, 2, 4], x, 1 / np.tan(x), "purple", "cot(x)", "Cotangent", ylim = (-10, 10))

plt.tight_layout() # automatically adjusting spacing to prevent overlapping elements
plt.show()
                                    
Chart

3D plots

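3D plots are used to visualize three-dimensional data, showing spatial relationships or surface structures. A 3D axes is created by passing `projection = "3d"` to `fig.add_subplot`; it adds a Z axis and methods such as `set_zlabel`.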

3D contour plots

# Simulating temperature on a plate: a 100 x 100 grid of (x, y) points from -5 to 5
coords = np.linspace(-5, 5, 100)
x, y = np.meshgrid(coords, coords)
z = np.exp(-(x ** 2 + y ** 2)) # temperature distribution, highest at the centre (0, 0)

fig = plt.figure()
ax = fig.add_subplot(111, projection = "3d")
ax.contour3D(x, y, z, 50, cmap = "coolwarm") # 50 contour levels

ax.set(
    xlabel = "X coordinate",
    ylabel = "Y coordinate",
    zlabel = "Temperature",
    title = "Heat distribution on a plate",
)

plt.show()
                                    
Chart
3D line plots (a helical spring)

# A parametric equation of a helix: (cos t, sin t) traces a unit circle
# while the height grows linearly with t
t = np.linspace(0, 10 * np.pi, 500) # 0 to 10*pi, i.e. five full turns
x = np.cos(t)
y = np.sin(t)
z = t / (2 * np.pi) # vertical axis: the height rises by 1 per turn

fig = plt.figure()
ax = fig.add_subplot(111, projection = "3d")
ax.plot(x, y, z, color = "blue", linewidth = 2)

ax.set_xlabel("X axis")
ax.set_ylabel("Y axis")
ax.set_zlabel("Height")
ax.set_title("Helical spring path")

plt.show() # displaying the graph (missing in the original example)
                                    
Chart
3D scatter plots

# Simulating earthquake data (latitude, longitude, depth)
np.random.seed(0) # fixed seed so the simulated data is reproducible
latitude = np.random.uniform(-90, 90, 100)
longitude = np.random.uniform(-180, 180, 100)
depth = np.random.uniform(0, 700, 100) # depth in km
magnitude = np.random.uniform(3, 8, 100) # magnitude for color

fig = plt.figure()
ax = fig.add_subplot(111, projection = "3d")
scatter = ax.scatter(longitude, latitude, depth, c = magnitude, cmap = "hot", s = 50)
# Flip the Z axis so the surface (depth 0) is at the top and deeper events lie lower,
# while the tick labels still show positive depths in km
ax.invert_zaxis()

ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_zlabel("Depth (km)")
ax.set_title("Earthquake epicenters (depth vs location)")

fig.colorbar(scatter, ax = ax, label = "Magnitude") # color scale for the magnitudes
plt.show()
                                    
Chart
3D surface plots (a gradient descent visualization)

# Simulating a 3D quadratic loss function on a 100 x 100 grid from -3 to 3
axis_values = np.linspace(-3, 3, 100)
x, y = np.meshgrid(axis_values, axis_values)
z = x ** 2 + y ** 2 # a simple convex loss function with its minimum at (0, 0)

# Gradient descent parameters
lr = 0.1 # learning rate (step size)
num_steps = 20
point = np.array([2.5, 2.5]) # the starting point
path = [point]

for _ in range(num_steps):
    grad = 2 * point # gradient of x^2 + y^2 is (2x, 2y)
    point = point - lr * grad # builds a new array, so earlier path entries are never modified
    path.append(point)

path = np.array(path)
z_path = (path ** 2).sum(axis = 1) # loss value at each visited point

fig = plt.figure()
ax = fig.add_subplot(111, projection = "3d")
ax.plot_surface(x, y, z, cmap = "viridis", alpha = 0.6) # semi-transparent so the path stays visible
ax.plot(path[:, 0], path[:, 1], z_path, color = "red", marker = "o", label = "Gradient descent path")

ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Loss")
ax.set_title("Gradient descent on a quadratic loss")
ax.legend()
plt.show()
                                    
Chart