NumPy: Calculate the Julia Set with Vectorization

What will we cover in this tutorial?

In this tutorial you will learn what the Julia set is, how it is calculated, and how it translates into colorful images. In the process, we will learn how to use vectorization with NumPy arrays to achieve it.

Step 1: Understand the Julia set

The Julia set is closely connected to the Mandelbrot set. If you are new to the Mandelbrot set, we recommend you read this tutorial before you proceed, as it will make the rest easier to understand.

Read this tutorial before if you are new to Mandelbrot and Julia sets.

Julia sets can be calculated for a function f. Here we consider the function f_c(z) = z^2 + c, for a complex number c, which is the same function used for the Mandelbrot set.

Recall that the Mandelbrot set is calculated by checking, for each point c, whether the sequence f_c(0), f_c(f_c(0)), f_c(f_c(f_c(0))), … diverges.

Said differently, for each point c on the complex plane, if the sequence does not diverge, then that point is in the Mandelbrot set.

The Julia set instead keeps c fixed and calculates the same kind of sequence for each point z in the complex plane. That is, for each point z in the complex plane, if the sequence f_c(z), f_c(f_c(z)), f_c(f_c(f_c(z))), … does not diverge, then z is part of the Julia set.

Step 2: Pseudo code for the non-vectorized computation of the Julia set

The best way to understand is often to first see the non-vectorized method to compute the Julia set.

As we consider the function f_c(z) = z^2 + c for our Julia set, we need to choose a complex number for c. Note that choosing a different c gives another Julia set.

Then we can iterate over each point z in the complex plane.

c = -0.8 + i*0.34
for x in [-1, 1] do:
  for y in [-1, 1] do:
    z = x + i*y
    N = 0
    while absolute(z) < 2 and N < MAX_ITERATIONS:
      z = z^2 + c
      N = N + 1
    set color for x,y to N

This provides beautiful color images of the Julia set.

Julia set generated from the implementation below.
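Before jumping to the vectorized version, a direct (and slow) Python translation of the pseudocode above could look like the sketch below. The grid resolution, MAX_ITERATIONS, and the plotting call are assumptions for illustration, not part of the pseudocode.

import numpy as np
import matplotlib.pyplot as plt

MAX_ITERATIONS = 100
c = -0.8 + 0.34j
width, height = 500, 500
image = np.zeros((height, width), dtype=int)
# Loop over every point z in the square [-1, 1] x [-1, 1]
for row, y in enumerate(np.linspace(-1, 1, height)):
    for col, x in enumerate(np.linspace(-1, 1, width)):
        z = complex(x, y)
        n = 0
        while abs(z) < 2 and n < MAX_ITERATIONS:
            z = z*z + c
            n += 1
        image[row, col] = n  # the iteration count is used as the color
plt.imshow(image, cmap='magma')
plt.show()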

Step 3: The vectorization computation using NumPy arrays

How does that translate into code using NumPy?

import numpy as np
import matplotlib.pyplot as plt

def julia_set(c=-0.4 + 0.6j, height=800, width=1000, x=0, y=0, zoom=1, max_iterations=100):
    # To make navigation easier we calculate these values
    x_width = 1.5
    y_height = 1.5*height/width
    x_from = x - x_width/zoom
    x_to = x + x_width/zoom
    y_from = y - y_height/zoom
    y_to = y + y_height/zoom
    # Here the actual algorithm starts
    x = np.linspace(x_from, x_to, width).reshape((1, width))
    y = np.linspace(y_from, y_to, height).reshape((height, 1))
    z = x + 1j * y
    # Broadcast c to an array with the same shape as z
    c = np.full(z.shape, c)
    # To keep track of the iteration in which each point diverged
    div_time = np.zeros(z.shape, dtype=int)
    # To keep track of which points have not diverged yet
    m = np.full(c.shape, True, dtype=bool)
    for i in range(max_iterations):
        z[m] = z[m]**2 + c[m]
        m[np.abs(z) > 2] = False
        div_time[m] = i
    return div_time

plt.imshow(julia_set(), cmap='magma')
# plt.imshow(julia_set(x=0.125, y=0.125, zoom=10), cmap='magma')
# plt.imshow(julia_set(c=-0.8j), cmap='magma')
# plt.imshow(julia_set(c=-0.8+0.156j, max_iterations=512), cmap='magma')
# plt.imshow(julia_set(c=-0.7269 + 0.1889j, max_iterations=256), cmap='magma')
plt.show()
Generated from the code above.
Generated from the code above.

NumPy: Compute Mandelbrot set by Vectorization

What will we cover in this tutorial?

  • Understand what the Mandelbrot set is and why it is so fascinating.
  • Master how to make images in multiple colors of the Mandelbrot set.
  • How to implement it using NumPy vectorization.

Step 1: What is Mandelbrot?

The Mandelbrot set is the set of complex numbers c for which the function f_c(z) = z^2 + c does not diverge when iterated from z=0 (from wikipedia).

Take a complex number, c, then you calculate the sequence for N iterations:

z_(n+1) = z_n^2 + c for n = 0, 1, …, N-1, with z_0 = 0

If absolute(z_N) < 2 after those N iterations, the sequence is said not to diverge and c is part of the Mandelbrot set.

The Mandelbrot set lives in the complex plane, which can be colored according to whether each point is part of the Mandelbrot set or not.

Mandelbrot set.

This only gives a black and white image of the complex plane, hence the images are often made more colorful by coloring each point by the iteration number at which it diverged. That is, if z_4 diverged for a point in the complex plane, then that point is given the color 4. That is how you end up with colorful maps like this.

Mandelbrot set (made by program from this tutorial).

Step 2: Understand the code of the non-vectorized approach to compute the Mandelbrot set

To better understand the images of the Mandelbrot set, think of the complex numbers as points in a diagram, where the real part of the complex number is the x-axis and the imaginary part is the y-axis (also called the Argand diagram).

Argand diagram

Then each point is a complex number c. That complex number will be given a color depending on the iteration in which it diverges (if it is not part of the Mandelbrot set).

Now the pseudocode for that should be easy to digest.

for x in [-2, 2] do:
  for y in [-1.5, 1.5] do:
    c = x + i*y
    z = 0
    N = 0
    while absolute(z) < 2 and N < MAX_ITERATIONS:
      z = z^2 + c
      N = N + 1
    set color for x,y to N

Simple enough to understand. That is some of the beauty of it. The simplicity.

Step 3: Make a vectorized version of the computations

Now that we understand the concepts behind it, we should translate that into a vectorized version. If you are new to vectorization, we recommend you read this tutorial first.

What do we achieve with vectorization? We compute on all the complex numbers simultaneously. To understand that, inspect the initialization of all the points here.

import numpy as np
def mandelbrot(height, width, x_from=-2, x_to=1, y_from=-1.5, y_to=1.5, max_iterations=100):
    x = np.linspace(x_from, x_to, width).reshape((1, width))
    y = np.linspace(y_from, y_to, height).reshape((height, 1))
    c = x + 1j * y

You see that we initialize all the x-coordinates at once using linspace. It creates an array with width evenly spaced numbers from x_from to x_to. The reshape is to fit the plane.

The same happens for y.

Then all the complex numbers are created in c = x + 1j*y, where 1j is the imaginary unit; the row of x-values and the column of y-values are broadcast against each other.
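To see the broadcasting at work, here is a tiny example with made-up sizes (4 columns and 3 rows), not the full image resolution:

import numpy as np

x = np.linspace(-2, 1, 4).reshape((1, 4))      # shape (1, 4): one row of x-values
y = np.linspace(-1.5, 1.5, 3).reshape((3, 1))  # shape (3, 1): one column of y-values
c = x + 1j * y                                 # broadcast to shape (3, 4): every combination
print(c.shape)  # (3, 4)
print(c)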

This leads us to the full implementation.

There are two things we need to keep track of in order to make a colorful Mandelbrot set. First, in which iteration each point diverged. Second, which points have already diverged, so we stop updating them.

import numpy as np
import matplotlib.pyplot as plt

def mandelbrot(height, width, x=-0.5, y=0, zoom=1, max_iterations=100):
    # To make navigation easier we calculate these values
    x_width = 1.5
    y_height = 1.5*height/width
    x_from = x - x_width/zoom
    x_to = x + x_width/zoom
    y_from = y - y_height/zoom
    y_to = y + y_height/zoom
    # Here the actual algorithm starts
    x = np.linspace(x_from, x_to, width).reshape((1, width))
    y = np.linspace(y_from, y_to, height).reshape((height, 1))
    c = x + 1j * y
    # Initialize z to all zero
    z = np.zeros(c.shape, dtype=np.complex128)
    # To keep track of the iteration in which each point diverged
    div_time = np.zeros(z.shape, dtype=int)
    # To keep track of which points have not diverged yet
    m = np.full(c.shape, True, dtype=bool)
    for i in range(max_iterations):
        z[m] = z[m]**2 + c[m]
        diverged = np.greater(np.abs(z), 2, out=np.full(c.shape, False), where=m) # Find diverging
        div_time[diverged] = i      # set the value of the diverged iteration number
        m[np.abs(z) > 2] = False    # to remember which have diverged
    return div_time

# Default image of Mandelbrot set
plt.imshow(mandelbrot(800, 1000), cmap='magma')
# The image below of Mandelbrot set
# plt.imshow(mandelbrot(800, 1000, -0.75, 0.0, 2, 200), cmap='magma')
# The image below of below of Mandelbrot set
# plt.imshow(mandelbrot(800, 1000, -1, 0.3, 20, 500), cmap='magma')
plt.show()

Notice that z[m] = z[m]**2 + c[m] only computes updates on values that have not yet diverged.
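A tiny example of such a masked update, with made-up values only to illustrate the indexing:

import numpy as np

z = np.array([0.5, 1.5, 3.0, 0.1])
m = np.abs(z) <= 2   # mask of values that have not "diverged"
z[m] = z[m]**2       # only the masked entries are squared
print(m)  # [ True  True False  True]
print(z)  # [0.25 2.25 3.   0.01]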

I have added the following two images from the commented-out calls above (the image for the call that is not commented out is shown in the previous step).

Mandelbrot set from the program above.
Mandelbrot set from the code above.
Also check out the tutorial on Julia sets.

NumPy: How does Sexual Compulsivity Scale Correlate with Men, Women, or Age?

Background

According to wikipedia, the Sexual Compulsivity Scale (SCS) is a psychometric measure of high libido, hypersexuality, and sexual addiction. While it does not say anything about the score itself, it is based on people rating 10 questions from 1 to 4.

The questions are the following.

Q1. My sexual appetite has gotten in the way of my relationships.				
Q2. My sexual thoughts and behaviors are causing problems in my life.				
Q3. My desires to have sex have disrupted my daily life.				
Q4. I sometimes fail to meet my commitments and responsibilities because of my sexual behaviors.				
Q5. I sometimes get so horny I could lose control.				
Q6. I find myself thinking about sex while at work.				
Q7. I feel that sexual thoughts and feelings are stronger than I am.				
Q8. I have to struggle to control my sexual thoughts and behavior.				
Q9. I think about sex more than I would like to.				
Q10. It has been difficult for me to find sex partners who desire having sex as much as I want to.

The questions are rated as follows (1=Not at all like me, 2=Slightly like me, 3=Mainly like me, 4=Very much like me).

A dataset of more than 3300 responses can be found here, which includes the individual rating of each question, the total score (the sum of ratings), age, and gender.

Step 1: First inspection of the data.

Inspection of the data (CSV file)

The first question that pops into my mind is how men and women rate themselves differently. How can we efficiently figure that out?

Welcome to NumPy. It has a built-in csv reader that does all the hard work in the genfromtxt function.

import numpy as np
data = np.genfromtxt('scs.csv', delimiter=',', dtype='int')
# Skip first row as it has description
data = data[1:]
men = data[data[:,11] == 1]
women = data[data[:,11] == 2]
print("Men average", men.mean(axis=0))
print("Women average", women.mean(axis=0))

Dividing into men and women is easy with NumPy, as you can make a vectorized conditional inside the dataset. Men are coded with 1 and women with 2 in column 11 (the 12th column). Finally, a call to mean will do the rest.

Men average [ 2.30544662  2.2453159   2.23485839  1.92636166  2.17124183  3.06448802
  2.19346405  2.28496732  2.43660131  2.54204793 23.40479303  1.
 32.54074074]
Women average [ 2.30959164  2.18993352  2.19088319  1.95916429  2.38746439  3.13010446
  2.18518519  2.2991453   2.4985755   2.43969611 23.58974359  2.
 27.52611586]

Interestingly, according to this dataset (whose accuracy should be taken with a grain of salt, as 21% of answers were not used), women score slightly higher on the SCS than men.

Men rate highest on the following question:

Q6. I find myself thinking about sex while at work.

While women rate highest on this question.

Q6. I find myself thinking about sex while at work.

The same. Also the lowest is the same for both genders.

Q4. I sometimes fail to meet my commitments and responsibilities because of my sexual behaviors.

Step 2: Visualize age vs score

I would guess that the SCS score decreases with age. Let’s see if that is the case.

Again, NumPy can easily do the magic, that is, prepare the data. To visualize it we use matplotlib, which is a comprehensive library for creating static, animated, and interactive visualizations in Python.

import numpy as np
import matplotlib.pyplot as plt
data = np.genfromtxt('scs.csv', delimiter=',', dtype='int')
# Skip first row as it has description
data = data[1:]
score = data[:,10]
age = data[:,12]
age[age > 100] = 0
plt.scatter(age, score, alpha=0.05)
plt.show()

Resulting in this plot.

Age vs SCS score.

It actually does not look like there is any correlation. Remember, there are more young people responding to the survey.

Let’s ask NumPy what it thinks about the correlation here. Luckily we can do that by calling the corrcoef function, which calculates the Pearson product-moment correlation coefficients.

print("Correlation of age and SCS score:", np.corrcoef(age, score))

Resulting in this output.

Correlation of age and SCS score:
[[1.         0.01046882]
 [0.01046882 1.        ]]

This says there is essentially no correlation, as 0.0 – 0.3 is considered a small correlation, and 0.01046882 is close to none. Does that mean the SCS score does not correlate with age? That our SCS score is static through life?

I do not think we can conclude that based on this small dataset.

Step 3: Bar plot the distribution of scores

In the graph we plotted, it also looked like there was a close to even distribution of scores.

Let’s try to see that. Here we need to count participants by score. NumPy falls a bit short here. But let’s keep the good mood and use plain old Python lists.

import numpy as np
import matplotlib.pyplot as plt
data = np.genfromtxt('scs.csv', delimiter=',', dtype='int')
# Skip first row as it has description
data = data[1:]
scores = []
numbers = []
for i in range(10, 41):
    numbers.append(i)
    scores.append(data[data[:, 10] == i].shape[0])
plt.bar(numbers, scores)
plt.show()

Resulting in this bar plot.

Count participants by score.

We knew that the average score was around 23, which could be consistent with an even distribution. But the counts seem to be a little lower at the far high end of the SCS score.
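That said, NumPy can also do the counting directly; a sketch using np.unique with return_counts=True on the same scs.csv (assuming the file and columns as above):

import numpy as np
import matplotlib.pyplot as plt

data = np.genfromtxt('scs.csv', delimiter=',', dtype='int')
# Skip first row as it has description
data = data[1:]
# Count how many participants got each total score (column 10)
scores, counts = np.unique(data[:, 10], return_counts=True)
plt.bar(scores, counts)
plt.show()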

For another great tutorial on NumPy check this one out, or learn some differences between NumPy and Pandas.

NumPy: Analyse Narcissistic Personality Indicator Numerical Dataset

What is Narcissistic Personality Indicator and how does it connect to NumPy?

NumPy is an amazing library that makes analyzing data easy, especially numerical data.

In this tutorial we are going to analyze a survey with 11,000+ respondents from an interactive Narcissistic Personality Indicator (NPI) test.

Narcissism is a personality trait generally conceived of as excessive self-love. In Greek mythology, Narcissus was a man who fell in love with his reflection in a pool of water.

https://openpsychometrics.org/tests/NPI/

The only connection between NPI and NumPy is that we want to analyze the 11,000+ answers.

The dataset can be downloaded here; it consists of a comma-separated file (or CSV file for short) and a description.

Step 1: Import the dataset and explore it

NumPy has thought of it for us; loading the dataset (from the link above) is as simple as magic.

import numpy as np
# This magic line loads the 11,000+ lines of data into an ndarray
data = np.genfromtxt('data.csv', delimiter=',', dtype='int')
# Skip first row
data = data[1:]
print(data)

And we print a summary out.

[[ 18   2   2 ... 211   1  50]
 [  6   2   2 ... 149   1  40]
 [ 27   1   2 ... 168   1  28]
 ...
 [  6   1   2 ... 447   2  33]
 [ 12   2   2 ... 167   1  24]
 [ 18   1   2 ... 291   1  36]]

A good idea is to inspect it in a spreadsheet as well.

Spreadsheet

And the far end.

Spreadsheet

Oh, that end.

Then investigate the description from the dataset. (Here we have some of it).

For questions 1-40 which choice they chose was recorded per the following key.
... [The questions Q1 ... Q40]
...
gender. Chosen from a drop down list (1=male, 2=female, 3=other; 0=none was chosen).
age. Entered as a free response. Ages below 14 have been ommited from the dataset.
-- CALCULATED VALUES --
elapse. (time submitted)-(time loaded) of the questions page in seconds.
score. = ((int) $_POST['Q1'] == 1)
... [How it is calculated]

That means we have the score, the answers to the questions, the elapsed time to answer, gender, and age.

Reading a bit more, it says that a high score is an indicator of narcissistic traits, but one should not conclude from the score alone that someone is a narcissist.

Step 2: Men or Women highest NPI?

I’m glad you asked.

import numpy as np
data = np.genfromtxt('data.csv', delimiter=',', dtype='int')
# Skip first row
data = data[1:]
# Extract all the NPI scores (first column)
npi_score = data[:,0]
print("Average score", npi_score.mean())
print("Men average", npi_score[data[:,42] == 1].mean())
print("Women average", npi_score[data[:,42] == 2].mean())
print("None average", npi_score[data[:,42] == 0].mean())
print("Other average", npi_score[data[:,42] == 3].mean())

Before looking at the result, see how nicely the first column is sliced out into the view npi_score. Then notice how easily you can calculate the mean based on a conditional rule to narrow the view.

Average score 13.29965311749533
Men average 14.195953307392996
Women average 12.081829626521191
None average 11.916666666666666
Other average 14.85

I guess you guessed it. Men score higher.

Step 3: Is there a correlation between age and NPI score?

I wonder about that too.

How can we figure that out? Wait, let’s ask our new friend NumPy.

import numpy as np
import matplotlib.pyplot as plt
data = np.genfromtxt('data.csv', delimiter=',', dtype='int')
# Skip first row
data = data[1:]
# Extract all the NPI scores (first column)
npi_score = data[:,0]
age = data[:,43]
# Some age values are not real, so we adjust them to 0
age[age>100] = 0
# Scatter plot them all with alpha=0.05
plt.scatter(age, npi_score, color='r', alpha=0.05)
plt.show()

Resulting in.

Plotting age vs NPI

That looks promising. But can we just conclude that younger people score higher NPI?

What if most respondents are young? That alone would make the picture denser in the younger end (15-30). The danger of judging by eye is jumping to conclusions.

Luckily, NumPy can help us there as well.

print(np.corrcoef(npi_score, age))

Resulting in.

Correlation of NPI score and age:
[[ 1.         -0.23414633]
 [-0.23414633  1.        ]]

What does that mean? Well, looking at the documentation of np.corrcoef():

Return Pearson product-moment correlation coefficients.

https://numpy.org/doc/stable/reference/generated/numpy.corrcoef.html

It is a negative correlation, which means the younger the respondent, the higher the NPI score tends to be. Values between 0.0 and -0.3 are considered low.

Is the Pearson product-moment correlation the correct one to use?
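One rank-based alternative is Spearman’s correlation, which does not assume a linear relationship. A sketch using SciPy (assuming it is installed and that npi_score and age are defined as above):

from scipy.stats import spearmanr

# Spearman rank correlation between NPI score and age
corr, p_value = spearmanr(npi_score, age)
print("Spearman correlation:", corr, "p-value:", p_value)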

Step 4: (Optional) Let’s try to see if there is a correlation between NPI score and time elapsed

Same code, different column.

import numpy as np
import matplotlib.pyplot as plt

data = np.genfromtxt('data.csv', delimiter=',', dtype='int')
# Skip first row
data = data[1:]
# Extract all the NPI scores (first column)
npi_score = data[:,0]
elapse = data[:,41]
elapse[elapse > 2000] = 2000
# Scatter plot them all with alpha=0.05
plt.scatter(elapse, npi_score, color='r', alpha=0.05)
plt.show()

Resulting in.

Time elapsed in seconds and NPI score

Again, it is tempting to conclude something here. We need to remember that the mean value is around 13, hence, most data will be around there.

If we use the same calculation.

print("Correlation of NPI score and time elapse:")
print(np.corrcoef(npi_score, elapse))

Output.

Correlation of NPI score and time elapse:
[[1.        0.0147711]
 [0.0147711 1.       ]]

Hence, here there is close to no correlation.

Conclusion

Use the scientific tools to conclude. Do not rely on your eyes to determine whether there is a correlation.

The above gives an idea on how easy it is to work with numerical data in NumPy.

Deleting Elements of a Python List while Iterating

What will we cover in this tutorial?

  • Understand the challenge with deleting elements while iterating over a Python list.
  • How to delete elements from a Python list while iterating over it.

Step 1: What happens when you just delete elements from a Python list while iterating over it?

Let’s first try this simple example to understand the challenge of deleting elements in a Python list while iterating over it.

a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
for e in a:
  a.remove(e)
print(a)

Now, looking at this piece of code, it would seem to be intended to delete all elements. But that is not happening. See, the output is.

[1, 3, 5, 7, 9]

Seems like every second element is deleted. Right?

Let’s try to understand that. When we enter the loop we see the following view.

for e (= 0, first element) in a (= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]):
  a.remove(e)

Then the first element is removed on the second line, then the view is.

for e (= 0, first element) in a (= [1, 2, 3, 4, 5, 6, 7, 8, 9]):
  a.remove(e) (a = [1, 2, 3, 4, 5, 6, 7, 8, 9])

Going into the second iteration it looks like this.

for e (= 2, second element) in a (= [1, 2, 3, 4, 5, 6, 7, 8, 9]):
  a.remove(e)

Hence, we see that the iterator takes the second element, which now is the number 2.

This explains why every second number is deleted from the list.

Step 2: What if we use index instead

Good idea. Let’s see what happens.

a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
for i, e in enumerate(a):
  a.pop(i)
print(a)

Which results in the same.

[1, 3, 5, 7, 9]

What if we iterate directly over the index by using the length of the list.

a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
for i in range(len(a)):
  a.pop(i)
print(a)

Oh, no.

Traceback (most recent call last):
  File "main.py", line 3, in <module>
    a.pop(i)
IndexError: pop index out of range

I get it. It is because len(a) is evaluated once, when the loop starts, and results in 10. Then when we reach i = 5, we have already popped 5 elements and have only 5 elements left. Hence, out of bounds.

Not convinced?

a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
for i in range(len(a)):
  print(i, len(a), a)
  a.pop(i)
print(a)

Resulting in.

0 10 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
1 9 [1, 2, 3, 4, 5, 6, 7, 8, 9]
2 8 [1, 3, 4, 5, 6, 7, 8, 9]
3 7 [1, 3, 5, 6, 7, 8, 9]
4 6 [1, 3, 5, 7, 8, 9]
5 5 [1, 3, 5, 7, 9]
Traceback (most recent call last):
  File "main.py", line 4, in <module>
    a.pop(i)
IndexError: pop index out of range

But what to do?

Step 3: How to delete elements while iterating over a list

The problem we actually want to solve is not to delete all the elements. It is to delete entries based on their values or some condition, where we need to inspect the values of the elements.

How can we do that?

By using list comprehension or by making a copy. Or is it the same, as list comprehension is creating a new copy, right?

Okay, one step at a time. Just see the following example.

a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
a = [i for i in a if i % 2 == 0]
print(a)

Resulting in a copy of the original list with only the even elements.

[0, 2, 4, 6, 8]

To see it is a copy you can evaluate the following code.

a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
b = a
a = [i for i in a if i % 2 == 0]
print(a)
print(b)

Resulting in the following, where you see that the variable a gets a new copy and the variable b still refers to the original (and unmodified) version.

[0, 2, 4, 6, 8]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Hence, the effect of the list comprehension construction above is as the following code shows.

a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# a = [i for i in a if i % 2 == 0]
c = []
for i in a:
    if i % 2 == 0:
        c.append(i)
a = c
print(a)

Getting what you want.

[0, 2, 4, 6, 8]
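If you specifically need to keep modifying the original list object (for example, because other variables reference it), you can also iterate over a copy while removing from the original. A small sketch, using even numbers as the condition:

a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
b = a               # another reference to the same list object
for e in a[:]:      # iterate over a shallow copy
    if e % 2 != 0:  # condition for deletion: odd numbers
        a.remove(e)
print(a)  # [0, 2, 4, 6, 8]
print(b)  # [0, 2, 4, 6, 8] - the same list object was modified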

Next steps

You can make the criteria more advanced by moving them into a function.

def criteria(v):
  # some advanced code that returns True or False
  ...
a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
a = [i for i in a if criteria(i)]

And if you want to keep a state of all previous criteria, then you can even use an Object to keep that stored.

class State:
  # ...
  def criteria(self, v):
    # ...
    ...
s = State()
a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
a = [i for i in a if s.criteria(i)]

Also, check out this tutorial that makes some observations on performance on list comprehensions.

How to Reverse a Python List in-place and Understand Common Mistakes

Understand the difference

What does in-place mean? When reversing a list it means to not create a new list.

In-place reversing can be done by calling reverse().

a = [i for i in range(19)]
b = a
print("Before reverse")
print(a)
print(b)
a.reverse()
print("After reverse")
print(a)
print(b)

Will result in the following.

Before reverse
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
After reverse
[18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
[18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

While using the slicing to get the list reversed will create a copy.

a = [i for i in range(19)]
b = a
print("Before reverse")
print(a)
print(b)
a = a[::-1] # Reverse using slicing
print("After reverse")
print(a)
print(b)

Resulting in.

Before reverse
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
After reverse
[18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]

It creates a new copy of the list a and keeps the original, which b points at, untouched.

Correct method: Swapping in a loop

Reversing a list in-place can be achieved with a call to reverse(), as shown above. A common job-interview question is to actually show an implementation of the reversal.

As you might guess, you can achieve that by swapping elements.

[0, 1, 2, 3, 4, 5, 6]
# Swap first and last
[6, 1, 2, 3, 4, 5, 0]
# Swap second and second last
[6, 5, 2, 3, 4, 1, 0]
# Swap third and third last
[6, 5, 4, 3, 2, 1, 0]
# Swap the middle with, ah, itself (just kidding, this step is not needed)
[6, 5, 4, 3, 2, 1, 0]

How can you implement that? Let’s try.

a = [i for i in range(19)]
b = a
print(a)
print(b)
for i in range(len(a)//2):
    a[i], a[len(a) - i - 1] = a[len(a) - i - 1], a[i]
print(a)
print(b)

Resulting in the following expected output.

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
[18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
[18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

Again, notice that I make b point to a, in order to convince you that we do the reversing in-place and not on a copy. If print(b) did not print output identical to print(a), we would have a problem to explain.

Pitfall: List comprehension (incorrect)

But wait a minute? Doesn’t list comprehension mean making a new list based on an existing list?

Correct!

But we can seemingly circumvent that with a bit of syntax (or can we? We recommend you read to the end).

a = [i for i in range(19)]
b = a
print(a)
print(b)
a[:] = [a[len(a) - i - 1] for i in range(len(a))]
print(a)
print(b)

The slice assignment a[:] forces Python to assign, or overwrite, the original values of a, which is why the output shows the following.

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
[18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
[18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

The catch is that [a[len(a) - i - 1] for i in range(len(a))] still creates a new list; the slice assignment then overwrites the values of a with it.

Bummer.
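For completeness, the same slice-assignment idea also works with the reversing slice itself. It still builds a temporary reversed copy before writing the values back, so it is not a true in-place algorithm either:

a = [i for i in range(19)]
b = a
a[:] = a[::-1]  # the reversed copy is written back into the original list object
print(a == b)   # True - b sees the reversed values as well
print(b[0])     # 18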

Conclusion

It is easy to draw false conclusions, as the example above shows. One of the beauties of Python is its extreme power at a high level. The price is that many things are hard to understand unless you dive deep into them.

List Comprehension in Python made Easy with Comparisons

What will we cover in this tutorial?

  • Understand how list comprehension works in Python.
  • Understand the memory difference between updating a list and creating a new one.
  • Test the performance difference between list comprehension and updating a list through a for-loop.

Step 1: Understand what is list comprehension

On wikipedia.org it is defined as follows.

list comprehension is a syntactic construct available in some programming languages for creating a list based on existing lists

Wikipedia.org

Then how does that translate into Python? Or is it, how does Python translate that into code?

Maybe this is the first time you hear about list comprehension, but if you have been programming in Python for some time, you have probably stumbled upon code pieces like this.

l1 = ['1', '2', '3', '4']
l2 = [int(s) for s in l1]
print(l2)

Which will result in a list of integers in l2.

[1, 2, 3, 4]

The construction of l2 is based on l1. Inspecting it closely, you can see a for-loop inside the square brackets. You could take the for-loop outside and achieve the same effect.

l1 = ['1', '2', '3', '4']
l2 = []
for s in l1:
  l2.append(int(s))
print(l2)

Nice.

Step 2: Updating and creation

Sometimes you see code like this.

l1 = [1, 2, 3, 4, 5, 6, 7]
l2 = [i + 1 for i in l1]
print(l2)

And you also notice that l1 is not used afterwards.

So what is the problem?

Let’s see an alternative way to do it.

l = [1, 2, 3, 4, 5, 6, 7]
for i in range(len(l)):
  l[i] += 1
print(l)

Which will result in the same effect. So what is the difference?

The first one, with list comprehension, creates a new list, while the second one updates the values of the list.

Not convinced? Investigate this piece of code.

def list_comprehension(l):
  return [i + 1 for i in l]
def update_loop(l):
  for i in range(len(l)):
    l[i] += 1
  return l
l1 = [1, 2, 3, 4, 5, 6, 7]
l2 = list_comprehension(l1)
print(l1, l2)
l1 = [1, 2, 3, 4, 5, 6, 7]
l2 = update_loop(l1)
print(l1, l2)

Which results in the following output.

[1, 2, 3, 4, 5, 6, 7] [2, 3, 4, 5, 6, 7, 8]
[2, 3, 4, 5, 6, 7, 8] [2, 3, 4, 5, 6, 7, 8]

As you see, the first one (list comprehension) creates a new list, while the other one updates the values in the existing.

From a memory perspective, this can be an issue with extremely large lists. But what about performance?

Step 3: Performance comparison between the two methods

This is interesting. To compare the run-time (performance) of the two functions we can use cProfile from the standard Python library.

import cProfile
import random

def list_comprehension(l):
    return [i + 1 for i in l]

def update_loop(l):
    for i in range(len(l)):
        l[i] += 1
    return l

def test(n, it):
    l = [random.randint(0, n) for i in range(n)]
    for i in range(it):
        list_comprehension(l)
    l = [random.randint(0, n) for i in range(n)]
    for i in range(it):
        update_loop(l)

cProfile.run('test(10000, 100000)')

This results in the following output.

         152917 function calls in 16.837 seconds
   Ordered by: standard name
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   16.837   16.837 <string>:1(<module>)
        1    0.869    0.869   16.837   16.837 TEST.py:15(test)
        1    0.008    0.008    0.040    0.040 TEST.py:16(<listcomp>)
        1    0.003    0.003    0.023    0.023 TEST.py:20(<listcomp>)
    10000    0.013    0.000    4.739    0.000 TEST.py:5(list_comprehension)
    10000    4.726    0.000    4.726    0.000 TEST.py:6(<listcomp>)
    10000   11.164    0.001   11.166    0.001 TEST.py:9(update_loop)
    20000    0.019    0.000    0.041    0.000 random.py:200(randrange)
    20000    0.010    0.000    0.052    0.000 random.py:244(randint)
    20000    0.014    0.000    0.022    0.000 random.py:250(_randbelow_with_getrandbits)
        1    0.000    0.000   16.837   16.837 {built-in method builtins.exec}
    10000    0.002    0.000    0.002    0.000 {built-in method builtins.len}
    20000    0.002    0.000    0.002    0.000 {method 'bit_length' of 'int' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
    32911    0.006    0.000    0.006    0.000 {method 'getrandbits' of '_random.Random' objects}

Where we can see that the accumulated time spent in list_comprehension is 4.739 seconds, while the accumulated time spent in update_loop is 11.166 seconds.

Wait, what? Is it faster to create a new list than to update an existing one?

Let’s do some more testing.

Performance of list comprehension vs updating a list

Seems to be no doubt about it.

Let’s just remember that Python is interpreted and each built-in construct is highly optimized. Hence, a list comprehension can be executed as one highly optimized construct, while updating the list in a loop is more flexible but requires more interpreted instructions per element.
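If you want a quick way to reproduce such a comparison yourself, a minimal sketch with the standard timeit module could look like this (the list size and repetition count are arbitrary choices):

import timeit

setup = "import random; l = [random.randint(0, 1000) for _ in range(10000)]"
# Time building a new list vs. updating the existing one in place
t_comp = timeit.timeit("[i + 1 for i in l]", setup=setup, number=1000)
t_loop = timeit.timeit("for i in range(len(l)):\n    l[i] += 1", setup=setup, number=1000)
print(f"list comprehension: {t_comp:.3f}s  update loop: {t_loop:.3f}s")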

Step 4 (Bonus): Use list comprehension with function

One aspect of list comprehension is that it limits what you can express, while the for-loop construct is more flexible.

But wait: if you use a function inside the list comprehension construction, you should be able to regain a lot of that flexibility.

Let’s try to see how that affects the performance.

import cProfile
import random
def add_one(v):
  return v + 1
def list_comprehension(l):
    return [add_one(i) for i in l]

def update_loop(l):
    for i in range(len(l)):
        l[i] += 1
    return l

def test(n, it):
    l = [random.randint(0, n) for i in range(n)]
    for i in range(it):
        list_comprehension(l)
    l = [random.randint(0, n) for i in range(n)]
    for i in range(it):
        update_loop(l)

cProfile.run('test(1000, 10000)')

Giving the following output.

         10050065 function calls in 15.826 seconds
   Ordered by: standard name
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   15.826   15.826 <string>:1(<module>)
    10000    3.960    0.000    3.964    0.000 main.py:11(update_loop)
        1    0.296    0.296   15.826   15.826 main.py:17(test)
        1    0.001    0.001    0.005    0.005 main.py:18(<listcomp>)
        1    0.004    0.004    0.008    0.008 main.py:22(<listcomp>)
 10000000    4.389    0.000    4.389    0.000 main.py:4(add_one)
    10000    0.077    0.000   11.554    0.001 main.py:7(list_comprehension)
    10000    7.088    0.001   11.476    0.001 main.py:8(<listcomp>)
     2000    0.003    0.000    0.006    0.000 random.py:200(randrange)
     2000    0.002    0.000    0.008    0.000 random.py:244(randint)
     2000    0.002    0.000    0.003    0.000 random.py:250(_randbelow_with_getrandbits)
        1    0.000    0.000   15.826   15.826 {built-in method builtins.exec}
    10000    0.004    0.000    0.004    0.000 {built-in method builtins.len}
     2000    0.000    0.000    0.000    0.000 {method 'bit_length' of 'int' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
     2059    0.000    0.000    0.000    0.000 {method 'getrandbits' of '_random.Random' objects}

Oh no. It takes list_comprehension 11.554 seconds compared to update_loop 3.964 seconds.

This is obviously hard for the interpreter to optimize, as it cannot predict the effect of the function (add_one). Adding that call in each iteration of the list creation adds a big performance overhead.

Conclusion

Can we conclude that list comprehension always beats updating an existing list? Not really. There is the memory dimension. If lists are big, memory is a scarce resource, or you want to avoid too much memory cleanup by Python, then updating the list might be a better option.

Sort a Python List with String of Integers or a Mixture

What will we cover in this tutorial?

  • How can you sort a list of strings containing integers by the integer value?
  • Or what if it contains both strings containing integers and integers?
  • Finally, what if only a substring contains the integer?

Why sorting a list of integers represented as strings fails

First off, we need to understand why it is not trivial to solve by just calling sort on the list.

Let’s just try with an example.

l = ['4', '8', '12', '23', '4']
l.sort()
print(l)

Which will result in the following list.

['12', '23', '4', '4', '8']

Where you see the list is sorted in lexicographical order and not by the numeric value the strings represent.
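The reason is that strings are compared character by character, so '12' comes before '4' because the character '1' comes before '4'. A quick check:

print('12' < '4')  # True - '1' is smaller than '4' as a character
print('23' < '8')  # True - for the same reason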

How to solve this

Solving this is quite straightforward if you know your way around Python. You look in the documentation and see that sort takes a key as an argument. Okay, but if you are new to this, what does that mean?

key specifies a function of one argument that is used to extract a comparison key from each list element

Python docs.

Still not comfortable with it? Let’s try to figure it out together. If you are new to Python, you might not know that you can pass functions as arguments like any other value.

The key argument is a function that will be applied on every item in the list. The output of that function will be used to make a simple comparison and order it by that.

That is great news. Why?

I am glad you asked. If we just use the int() function as argument, it should cast the string to an integer and use that for comparison and our problem is solved.

Let’s try.

l = ['4', '8', '12', '23', '4']
l.sort(key=int)
print(l)

Resulting in the following list.

['4', '4', '8', '12', '23']

How simple is that?

What if my list is a mixture of integers and strings of integers?

What is your wild guess?

l = ['4', '8', 12, '23', 4]
l.sort(key=int)
print(l)

Notice that some integers are not strings anymore. Let’s see the output.

['4', 4, '8', 12, '23']

It works. This is why we love Python!

But what if it is more complex?

A complex examples of sorting

Say we have a list of strings like this one.

l = ['4 dollars', '8 dollars', '12 dollars', '23 dollars', '4 dollars']

The story is something like this. You ask a lot of providers how much it will cost to give a specific service. The answers are given in the list and you want to investigate them in order of lowest price.

We can just do the same, right?

l = ['4 dollars', '8 dollars', '12 dollars', '23 dollars', '4 dollars']
l.sort(key=int)
print(l)

Wrong!

Traceback (most recent call last):
  File "main.py", line 2, in <module>
    l.sort(key=int)
ValueError: invalid literal for int() with base 10: '4 dollars'

The string is not just an integer. It contains more information.

Luckily, we can pass any function as the key. Let’s try to create one.

def comp(o):
  return int(o.split()[0])
l = ['4 dollars', '8 dollars', '12 dollars', '23 dollars', '4 dollars']
l.sort(key=comp)
print(l)

And the output is as desired.

['4 dollars', '4 dollars', '8 dollars', '12 dollars', '23 dollars']

Too fast? Let’s just analyse our function comp. It contains only one return statement. Try to read it from inside out.

o.split() splits the string up into a list of words. Hence, the call '4 dollars'.split() will result in ['4', 'dollars'].

Then o.split()[0] will return the first item of that list, i.e. '4'.

Finally, we cast it to an integer by int(o.split()[0]).

Remember that the comparison is done by the output of the function, that is what the function returns, which in this case is the integer represented by the first item in the string.

What about lambda?

Lambda? Yes, lambda functions are also a hot subject.

A lambda function is just a smart way to write simple functions you send as arguments to other functions. Like in this case a sorting function.

Let’s try if we can do that.

l = ['4 dollars', '8 dollars', '12 dollars', '23 dollars', '4 dollars']
l.sort(key=lambda o: int(o.split()[0]))
print(l)

Resulting in the same output.

['4 dollars', '4 dollars', '8 dollars', '12 dollars', '23 dollars']

A bit of magic with lambda functions? We advise you to read this tutorial on the subject.

Pandas DataFrame Merge: Inner, Outer, Left, and Right

What will we cover in this tutorial?

A key process in Data Science is to merge data from various sources. This can be challenging, and it is often unclear which kind of merge to use. Here we will take a simple example and explain the differences between the ways to merge data using the merge function of the pandas library’s DataFrame object.

The key ways to merge are inner, outer, left, and right.

In this example we are going to explore what correlates the most with GDP per capita: yearly meat consumption, yearly beer consumption, or long-term unemployment.

What is your educated guess? (no cheating, the result is down below)

Step 1: The data we want to merge

That means we need to gather the specified data.

The GDP per capita can be found on wikipedia.org. As we are going to do a lot of the same code again and again, let’s make a helper function to get the data, index the correct table, and drop the data we do not use in our analysis.

This can be done like this.

import pandas as pd

# This is simply used to display all the data and not get a small window of it
pd.set_option('display.max_rows', 300)
pd.set_option('display.max_columns', 15)
pd.set_option('display.width', 1000)

# This is a helper function, read the URL, get the right table, drop some columns
def read_table(url, table_number, drop_columns):
    tables = pd.read_html(url)
    table = tables[table_number]
    table = table.drop(drop_columns, axis=1)
    return table

# GDP per capita
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_GDP_(nominal)_per_capita'
table = read_table(url, 3, ['Rank'])
table.rename(columns={'Country/Territory': 'Country'}, inplace=True)
print(table)

Which results in this output (or the first few lines of it).

                                    Country     US$
0                             Monaco (2018)  185741
1                      Liechtenstein (2017)  173356
2                                Luxembourg  114705
3                                     Macau   84096
4                               Switzerland   81994
5                                   Ireland   78661
6                                    Norway   75420
7                                   Iceland   66945

Comparing this to wikipedia.org.

From wikipedia.org

We can identify that this is the middle one of the GDP tables, the one based on the World Bank figures.

Then we need data from the other sources. Here we get it for long-term unemployment (defined as being unemployed for 1 year or more).

# Long-term unemployment
url = 'https://en.wikipedia.org/wiki/List_of_OECD_countries_by_long-term_unemployment_rate'
table_join = read_table(url, 0, ['Long-term unemployment rate (2012)[1]'])
table_join.rename(columns={'Country/Territory': 'Country', 'Long-term unemployment rate (2016)[1]': 'Long-term unemployment rate'}, inplace=True)
index = 'Long-term unemployment rate'
table_join[index] = table_join[index].str[:-1].astype(float)
print(table_join)

Resulting in the following output

           Country  Long-term unemployment rate
0        Australia                         1.32
1          Austria                         1.53
2          Belgium                         4.26
3           Brazil                         0.81
4           Canada                         0.89
5            Chile                         1.67
6   Czech Republic                         2.72
7          Denmark                         1.66

This can be done for the two other dimensions we want to explore as well. We will skip it here, as the full code comes later.

Step 2: Simply merge it together

What happens if we merge the data together without considering which type of merge to use?

Skip reading the documentation also. Let’s just do it.

import pandas as pd

# This is simply used to display all the data and not get a small window of it
pd.set_option('display.max_rows', 300)
pd.set_option('display.max_columns', 15)
pd.set_option('display.width', 1000)

# This is a helper function, read the URL, get the right table, drop some columns
def read_table(url, table_number, drop_columns):
    tables = pd.read_html(url)
    table = tables[table_number]
    table = table.drop(drop_columns, axis=1)
    return table

# GDP per capita
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_GDP_(nominal)_per_capita'
table = read_table(url, 3, ['Rank'])
table.rename(columns={'Country/Territory': 'Country'}, inplace=True)
# Long-term unemployment
url = 'https://en.wikipedia.org/wiki/List_of_OECD_countries_by_long-term_unemployment_rate'
table_join = read_table(url, 0, ['Long-term unemployment rate (2012)[1]'])
table_join.rename(columns={'Country/Territory': 'Country', 'Long-term unemployment rate (2016)[1]': 'Long-term unemployment rate'}, inplace=True)
index = 'Long-term unemployment rate'
table_join[index] = table_join[index].str[:-1].astype(float)
table = pd.merge(table, table_join)
# Meat consumption
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_meat_consumption'
table_join = read_table(url, 1, ['Kg/person (2009)[10]'])
table_join.rename(columns={'Kg/person (2002)[9][note 1]': 'Kg meat/person'}, inplace=True)
table = pd.merge(table, table_join)
# Beer consumption
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_beer_consumption_per_capita'
table_join = read_table(url, 2, ['2018change(litres per year)', 'Total nationalconsumption[a](million litresper year)', 'Year', 'Sources'])
table_join.rename(columns={'Consumptionper capita[1](litres per year)': 'Liter beer/person'}, inplace=True)
table = pd.merge(table, table_join)

print(table)
# Calculate the correlation
table_corr = table.corr()
# Print the correlation to GDP per capita (stored in US$).
print(table_corr['US$'].sort_values(ascending=False))

Which results in the following output from the first print statement (this is the full output).

           Country    US$  Long-term unemployment rate  Kg meat/person  Liter beer/person
0      Switzerland  81994                         1.71            72.9               55.5
1          Ireland  78661                         6.68           106.3               95.8
2          Denmark  59822                         1.66           145.9               59.6
3        Australia  54907                         1.32           108.2               76.3
4      Netherlands  52448                         2.98            89.3               78.1
5          Austria  50277                         1.53            94.1              107.6
6          Finland  48686                         1.97            67.4               76.7
7          Germany  46259                         2.21            82.1              101.1
8           Canada  46195                         0.89           108.1               55.7
9          Belgium  46117                         4.26            86.1               67.0
10          Israel  43641                         0.63            97.1               17.4
11  United Kingdom  42300                         2.22            79.6               72.9
12     New Zealand  42084                         0.78           142.1               65.5
13          France  40494                         4.21           101.1               33.0
14           Japan  40247                         1.36            45.9               41.4
15           Italy  33190                         7.79            90.4               31.0
16           Spain  29614                        12.92           118.6               86.0
17        Slovenia  25739                         5.27            88.0               80.2
18  Czech Republic  23102                         2.72            77.3              191.8
19        Slovakia  19329                         8.80            67.4               83.5
20         Hungary  16476                         3.78           100.7               76.8
21          Poland  15595                         3.26            78.1               98.2
22          Mexico   9863                         0.06            58.6               68.7
23          Turkey   9043                         2.04            19.3               13.0
24          Brazil   8717                         0.81            82.4               60.0

Strange, you might think? There are only 25 countries (indices 0 to 24). Also, let’s look at the actual correlation between columns, which is the output of the second print statement.

US$                            1.000000
Kg meat/person                 0.392070
Liter beer/person             -0.021863
Long-term unemployment rate   -0.086968
Name: US$, dtype: float64

Correlations are quite low. It correlates the most with meat, but still not that much.

Step 3: Let’s read the types of merge available

Reading the documentation of merge, you will notice there are four types of merge.

  • left: use only keys from left frame, similar to a SQL left outer join; preserve key order.
  • right: use only keys from right frame, similar to a SQL right outer join; preserve key order.
  • outer: use union of keys from both frames, similar to a SQL full outer join; sort keys lexicographically.
  • inner: use intersection of keys from both frames, similar to a SQL inner join; preserve the order of the left keys.

We also see that inner merge is the default. So what does inner merge do?

It means it will only merge on keys which exist in both DataFrames. Translated to our tables, it means that the only remaining rows in the final merged table are the ones which exist in all 4 tables.

You can check that: it is exactly the 25 countries listed there.
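To see the difference between inner and outer on a small example (toy data made up for illustration, merged on the shared Country column):

import pandas as pd

left = pd.DataFrame({'Country': ['Denmark', 'Germany', 'Monaco'], 'US$': [59822, 46259, 185741]})
right = pd.DataFrame({'Country': ['Denmark', 'Germany', 'Brazil'], 'Liter beer/person': [59.6, 101.1, 60.0]})
# inner keeps only countries present in both frames
print(pd.merge(left, right, how='inner'))
# outer keeps all countries and fills the missing values with NaN
print(pd.merge(left, right, how='outer'))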

Step 4: Understand what we should do

What we are doing in the end is correlating against the GDP per capita. Hence, it only makes sense to keep the rows that have a GDP.

If we used an outer merge, we would keep the union of all keys. That would not give any additional value to the calculations we want to do.

But let’s just try it and investigate the output.

import pandas as pd

# This is simply used to display all the data and not get a small window of it
pd.set_option('display.max_rows', 300)
pd.set_option('display.max_columns', 15)
pd.set_option('display.width', 1000)

# This is a helper function, read the URL, get the right table, drop some columns
def read_table(url, table_number, drop_columns):
    tables = pd.read_html(url)
    table = tables[table_number]
    table = table.drop(drop_columns, axis=1)
    return table

# GDP per capita
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_GDP_(nominal)_per_capita'
table = read_table(url, 3, ['Rank'])
table.rename(columns={'Country/Territory': 'Country'}, inplace=True)
# Long-term unemployment
url = 'https://en.wikipedia.org/wiki/List_of_OECD_countries_by_long-term_unemployment_rate'
table_join = read_table(url, 0, ['Long-term unemployment rate (2012)[1]'])
table_join.rename(columns={'Country/Territory': 'Country', 'Long-term unemployment rate (2016)[1]': 'Long-term unemployment rate'}, inplace=True)
index = 'Long-term unemployment rate'
table_join[index] = table_join[index].str[:-1].astype(float)
table = pd.merge(table, table_join, how='outer')
# Meat consumption
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_meat_consumption'
table_join = read_table(url, 1, ['Kg/person (2009)[10]'])
table_join.rename(columns={'Kg/person (2002)[9][note 1]': 'Kg meat/person'}, inplace=True)
table = pd.merge(table, table_join, how='outer')
# Beer consumption
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_beer_consumption_per_capita'
table_join = read_table(url, 2, ['2018change(litres per year)', 'Total nationalconsumption[a](million litresper year)', 'Year', 'Sources'])
table_join.rename(columns={'Consumptionper capita[1](litres per year)': 'Liter beer/person'}, inplace=True)
table = pd.merge(table, table_join, how='outer')

print(table)
# Calculate the correlation
table_corr = table.corr()
# Print the correlation to GDP per capita (stored in US$).
print(table_corr['US$'].sort_values(ascending=False))

First of all, this keeps all the rows. I will not put the full output here, but only show a few lines.

                                    Country       US$  Long-term unemployment rate  Kg meat/person  Liter beer/person
0                             Monaco (2018)  185741.0                          NaN             NaN                NaN
1                      Liechtenstein (2017)  173356.0                          NaN             NaN                NaN
2                                Luxembourg  114705.0                         1.60           141.7                NaN
222                United States of America       NaN                          NaN           124.8                NaN
223            United States Virgin Islands       NaN                          NaN             6.6                NaN
224                               Venezuela       NaN                          NaN            56.6                NaN
225                                  Taiwan       NaN                          NaN             NaN               23.2

As the sample lines above show, we get a row if at least one of the columns is not NaN. Before, when we used inner, we would only get rows where all columns were not NaN.

The output of the correlation is now.

US$                            1.000000
Kg meat/person                 0.706692
Liter beer/person              0.305120
Long-term unemployment rate   -0.249958
Name: US$, dtype: float64

These are different values than in the previous example. Surprised? Not really. Now we have more data to correlate.

Step 5: Do the correct thing

If we inspect the code, we can see that we start by having the GDP table on the left side. This growing table is always kept on the left side. Hence, we should be able to merge with left. Notice that this should not affect the final result.

Let’s try it.

import pandas as pd

# This is simply used to display all the data and not get a small window of it
pd.set_option('display.max_rows', 300)
pd.set_option('display.max_columns', 15)
pd.set_option('display.width', 1000)

# This is a helper function, read the URL, get the right table, drop some columns
def read_table(url, table_number, drop_columns):
    tables = pd.read_html(url)
    table = tables[table_number]
    table = table.drop(drop_columns, axis=1)
    return table

# GDP per capita
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_GDP_(nominal)_per_capita'
table = read_table(url, 3, ['Rank'])
table.rename(columns={'Country/Territory': 'Country'}, inplace=True)
# Long-term unemployment
url = 'https://en.wikipedia.org/wiki/List_of_OECD_countries_by_long-term_unemployment_rate'
table_join = read_table(url, 0, ['Long-term unemployment rate (2012)[1]'])
table_join.rename(columns={'Country/Territory': 'Country', 'Long-term unemployment rate (2016)[1]': 'Long-term unemployment rate'}, inplace=True)
index = 'Long-term unemployment rate'
table_join[index] = table_join[index].str[:-1].astype(float)
table = pd.merge(table, table_join, how='left')
# Meat consumption
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_meat_consumption'
table_join = read_table(url, 1, ['Kg/person (2009)[10]'])
table_join.rename(columns={'Kg/person (2002)[9][note 1]': 'Kg meat/person'}, inplace=True)
table = pd.merge(table, table_join, how='left')
# Beer consumption
url = 'https://en.wikipedia.org/wiki/List_of_countries_by_beer_consumption_per_capita'
table_join = read_table(url, 2, ['2018change(litres per year)', 'Total nationalconsumption[a](million litresper year)', 'Year', 'Sources'])
table_join.rename(columns={'Consumptionper capita[1](litres per year)': 'Liter beer/person'}, inplace=True)
table = pd.merge(table, table_join, how='left')

print(table)
# Calculate the correlation
table_corr = table.corr()
# Print the correlation to GDP per capita (stored in US$).
print(table_corr['US$'].sort_values(ascending=False))

Resulting in the same final print statement.

US$                            1.000000
Kg meat/person                 0.706692
Liter beer/person              0.305120
Long-term unemployment rate   -0.249958
Name: US$, dtype: float64

Question: What does the data tell us?

Good question. What does our finding tell us? Let’s inspect the final output.

US$                            1.000000
Kg meat/person                 0.706692
Liter beer/person              0.305120
Long-term unemployment rate   -0.249958
Name: US$, dtype: float64

The row with US$ shows the full correlation to GDP per capita, which obviously has 100% (1.00000) correlation to GDP per capita, as it is the number itself.

The second row tells us that eating a lot of meat is highly correlated to GDP per capita. Does that then mean that a country should encourage all citizens to eat more meat to become richer? No, you cannot conclude that. It is probably the other way around. The richer a country is, the more meat they eat.

The last line tells us that long-term unemployment is negatively correlated with GDP per capita. That is not surprising. It means the more long-term unemployed people, the lower the GDP per capita. But it is not highly correlated, only (approximately) -25%.

Surprisingly, drinking a lot of beer seems to have a bigger positive correlation with GDP per capita than long-term unemployment has a negative one.

What a wonderful world.

Pandas: Data Preparation with Vectorized Strings vs Lambda Functions

What will we cover in this tutorial?

We will prepare data read from a Wikipedia table in two ways: with the vectorized string functions that pandas provides, and with a lambda function applied row by row. Finally, we compare the two approaches.

Understand the challenge

Most of the time when you read data into a pandas DataFrame, it needs to be prepared.

To be more concrete, let’s look at an example. Say we want to look at the List of largest companies by revenue on Wikipedia.

From wikipedia.org

You can read and inspect the data with the following code, using a DataFrame from the pandas library.

import pandas as pd
pd.set_option('display.max_rows', 300)
pd.set_option('display.max_columns', 10)
pd.set_option('display.width', 1000)

url = 'https://en.wikipedia.org/wiki/List_of_largest_companies_by_revenue'
tables = pd.read_html(url)
table = tables[0]
print(table)
print(table.dtypes)

Notice that we use pd.set_option calls to get the full view. If you are new to read_html from the pandas library, we recommend you read this tutorial.

The top of the output will be as follows.

    Rank                        Name                     Industry Revenue(USD millions) Profit(USD millions)  Employees                       Country   Ref
0      1                     Walmart                       Retail              $514,405               $6,670    2200000                 United States   [5]
1      2               Sinopec Group                  Oil and gas              $414,650               $5,845     619151                         China   [6]
2      3           Royal Dutch Shell                  Oil and gas              $396,556              $23,352      81000  Netherlands / United Kingdom   [7]
3      4    China National Petroleum                  Oil and gas              $392,976               $2,270    1382401                         China   [8]
4      5                  State Grid                  Electricity              $387,056               $8,174     917717                         China   [9]
5      6                Saudi Aramco                  Oil and gas              $355,905             $110,974      76418                  Saudi Arabia  [10]
6      7                          BP                  Oil and gas              $303,738               $9,383      73000                United Kingdom  [11]
7      8                  ExxonMobil                  Oil and gas              $290,212              $20,840      71000                 United States  [12]
8      9                  Volkswagen                   Automotive              $278,341              $14,332     664496                       Germany  [13]

And the last lines.

Rank                      int64
Name                     object
Industry                 object
Revenue(USD millions)    object
Profit(USD millions)     object
Employees                 int64
Country                  object
Ref                      object
dtype: object

Here we see what data type each column has. Not surprisingly, the Revenue and Profit columns are of type object (which are strings in this case).

Hence, if we want to sum up values, we need to transform them to floats. This is a bit tricky, as the output above shows. Take $6,670 as an example: there are two issues when converting it to a float. First, there is a dollar sign ($) at the beginning. Second, there is a comma (,) in the number, which a simple cast to float does not handle.
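
As a minimal sketch of the problem in plain Python, using one of the values from the table above:

value = '$6,670'
# float(value) would raise a ValueError because of the $ and the comma
cleaned = value.replace('$', '').replace(',', '')
print(float(cleaned))  # 6670.0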

Now let’s deal with them, one method at a time.

Method 1: Using pandas DataFrame/Series vectorized string functions

Vectorization with pandas data structures is the process of executing operations on the entire data structure at once. This is handy, as the alternative would be to loop over the rows yourself.

Also, pandas has many vectorized string functions available, as you can see in the documentation.
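
As a minimal sketch with a made-up Series, the .str accessor lets you apply an ordinary string operation to every element at once instead of looping over them.

import pandas as pd

s = pd.Series([' Walmart ', 'Sinopec Group'])
print(s.str.strip())  # removes the surrounding whitespace from every element
print(s.str.lower())  # lower-cases every element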

First off, we can access the string methods through the .str accessor and then apply the string function we need. In our case, we will use slicing with square brackets to remove the dollar sign.

index_r = 'Revenue(USD millions)'
table[index_r] = table[index_r].str[1:]

Which will give the following output.

    Rank                        Name                     Industry Revenue(USD millions) Profit(USD millions)  Employees                       Country   Ref
0      1                     Walmart                       Retail               514,405               $6,670    2200000                 United States   [5]
1      2               Sinopec Group                  Oil and gas               414,650               $5,845     619151                         China   [6]
2      3           Royal Dutch Shell                  Oil and gas               396,556              $23,352      81000  Netherlands / United Kingdom   [7]
3      4    China National Petroleum                  Oil and gas               392,976               $2,270    1382401                         China   [8]
4      5                  State Grid                  Electricity               387,056               $8,174     917717                         China   [9]
5      6                Saudi Aramco                  Oil and gas               355,905             $110,974      76418                  Saudi Arabia  [10]
6      7                          BP                  Oil and gas               303,738               $9,383      73000                United Kingdom  [11]
7      8                  ExxonMobil                  Oil and gas               290,212              $20,840      71000                 United States  [12]
8      9                  Volkswagen                   Automotive               278,341              $14,332     664496                       Germany  [13]

Then we need to remove the comma (,). This can be done by using replace.

index_r = 'Revenue(USD millions)'
table[index_r] = table[index_r].str[1:].str.replace(',', '')

Which will result in the following output.

    Rank                        Name                     Industry Revenue(USD millions) Profit(USD millions)  Employees                       Country   Ref
0      1                     Walmart                       Retail                514405               $6,670    2200000                 United States   [5]
1      2               Sinopec Group                  Oil and gas                414650               $5,845     619151                         China   [6]
2      3           Royal Dutch Shell                  Oil and gas                396556              $23,352      81000  Netherlands / United Kingdom   [7]
3      4    China National Petroleum                  Oil and gas                392976               $2,270    1382401                         China   [8]
4      5                  State Grid                  Electricity                387056               $8,174     917717                         China   [9]
5      6                Saudi Aramco                  Oil and gas                355905             $110,974      76418                  Saudi Arabia  [10]
6      7                          BP                  Oil and gas                303738               $9,383      73000                United Kingdom  [11]
7      8                  ExxonMobil                  Oil and gas                290212              $20,840      71000                 United States  [12]
8      9                  Volkswagen                   Automotive                278341              $14,332     664496                       Germany  [13]

Finally, we need to convert the string to a float.

index_r = 'Revenue(USD millions)'
table[index_r] = table[index_r].str[1:].str.replace(',', '').astype(float)

Which does not change the printed output, but it does change the type of the column.

Nice and easy to prepare the data in one line. Notice that you could choose to do it in multiple lines instead, as sketched below. It is a matter of taste.
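
For reference, here is a sketch of the same transformation split over multiple lines, using the same column as above.

index_r = 'Revenue(USD millions)'
table[index_r] = table[index_r].str[1:]               # remove the leading dollar sign
table[index_r] = table[index_r].str.replace(',', '')  # remove the thousands separators
table[index_r] = table[index_r].astype(float)         # cast the cleaned strings to floats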

Method 2: Using pandas DataFrame lambda function

Another way to prepare data is by using a lambda function. If you are new to lambda functions, we recommend you read this tutorial.

Here you go through the DataFrame row by row and apply your lambda function to each row.
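
As a minimal sketch with a made-up DataFrame, apply with axis=1 calls the lambda once per row and builds a new Series from the returned values.

import pandas as pd

df = pd.DataFrame({'name': ['Walmart', 'BP']})
# The lambda receives one row at a time and returns a single value
df['name_length'] = df.apply(lambda row: len(row['name']), axis=1)
print(df)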

The next column has the same challenge as the first one, so let’s apply this approach to it.

In this case, we cannot simply slice off the first character like above, as some figures are negative and have a minus sign before the dollar sign. But using replace to remove the dollar sign will do fine.
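
As a small illustration, using a made-up negative figure formatted as described, slicing would strip the wrong character, while replace targets the dollar sign wherever it is.

value = '-$1,194'              # hypothetical negative figure
print(value[1:])               # '$1,194'  -- the minus sign is lost, not the dollar sign
print(value.replace('$', ''))  # '-1,194'  -- the dollar sign is removed, the minus is kept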

index_p = 'Profit(USD millions)'
table[index_p] = table.apply(lambda row: row[index_p].replace('$', ''), axis=1)

Which would result in the following output.

    Rank                        Name                     Industry  Revenue(USD millions) Profit(USD millions)  Employees                       Country   Ref
0      1                     Walmart                       Retail               514405.0                6,670    2200000                 United States   [5]
1      2               Sinopec Group                  Oil and gas               414650.0                5,845     619151                         China   [6]
2      3           Royal Dutch Shell                  Oil and gas               396556.0               23,352      81000  Netherlands / United Kingdom   [7]
3      4    China National Petroleum                  Oil and gas               392976.0                2,270    1382401                         China   [8]
4      5                  State Grid                  Electricity               387056.0                8,174     917717                         China   [9]
5      6                Saudi Aramco                  Oil and gas               355905.0              110,974      76418                  Saudi Arabia  [10]
6      7                          BP                  Oil and gas               303738.0                9,383      73000                United Kingdom  [11]
7      8                  ExxonMobil                  Oil and gas               290212.0               20,840      71000                 United States  [12]
8      9                  Volkswagen                   Automotive               278341.0               14,332     664496                       Germany  [13]

Then we do the same to remove the comma (,).

index_p = 'Profit(USD millions)'
table[index_p] = table.apply(lambda row: row[index_p].replace('$', '').replace(',', ''), axis=1)

Which results in the following output.

    Rank                        Name                     Industry  Revenue(USD millions) Profit(USD millions)  Employees                       Country   Ref
0      1                     Walmart                       Retail               514405.0                 6670    2200000                 United States   [5]
1      2               Sinopec Group                  Oil and gas               414650.0                 5845     619151                         China   [6]
2      3           Royal Dutch Shell                  Oil and gas               396556.0                23352      81000  Netherlands / United Kingdom   [7]
3      4    China National Petroleum                  Oil and gas               392976.0                 2270    1382401                         China   [8]
4      5                  State Grid                  Electricity               387056.0                 8174     917717                         China   [9]
5      6                Saudi Aramco                  Oil and gas               355905.0               110974      76418                  Saudi Arabia  [10]
6      7                          BP                  Oil and gas               303738.0                 9383      73000                United Kingdom  [11]
7      8                  ExxonMobil                  Oil and gas               290212.0                20840      71000                 United States  [12]
8      9                  Volkswagen                   Automotive               278341.0                14332     664496                       Germany  [13]

Finally, we do the same to cast it to a float.

index_p = 'Profit(USD millions)'
table[index_p] = table.apply(lambda row: float(row[index_p].replace('$', '').replace(',', '')), axis=1)

Which will produce the same output.

Comparing the two methods

To be honest, in this case it is a matter of taste. When things can be achieved with the simple string manipulations that are available as vectorized calls, there is nothing to gain from lambda functions.

The strength of lambda functions is their flexibility. You can run any function in there, which is a big advantage. The vectorized string functions are limited to simpler operations, although that covers a lot of use cases.
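
As a sketch of that flexibility (not part of the tutorial's pipeline), a lambda can easily combine several columns of a row, for example a hypothetical profit-to-revenue ratio.

index_r = 'Revenue(USD millions)'
index_p = 'Profit(USD millions)'
# Assumes both columns have already been converted to floats as shown above
table['Profit ratio'] = table.apply(lambda row: row[index_p] / row[index_r], axis=1)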

Putting it all together

Well, now that we have come this far, let’s put it all together and get some nice data. We sum it up per group, print it sorted, and make a horizontal bar plot.

import pandas as pd
import matplotlib.pyplot as plt
pd.set_option('display.max_rows', 300)
pd.set_option('display.max_columns', 10)
pd.set_option('display.width', 1000)

url = 'https://en.wikipedia.org/wiki/List_of_largest_companies_by_revenue'
tables = pd.read_html(url)
table = tables[0]
index_r = 'Revenue(USD millions)'
table[index_r] = table[index_r].str[1:].str.replace(',', '').astype(float)
index_p = 'Profit(USD millions)'
table[index_p] = table.apply(lambda row: float(row[index_p].replace('$', '').replace(',', '')), axis=1)
table = table.drop(['Rank'], axis=1)

print(table.groupby('Country').sum().sort_values([index_r, index_p], ascending=False))
table.groupby('Industry').sum().sort_values([index_r, index_p], ascending=False).plot.barh()
plt.show()

The output graph.

Resulting output graph.

And the output from the program.

                              Revenue(USD millions)  Profit(USD millions)  Employees
Country                                                                             
United States                             4169049.0              243970.0    6585076
China                                     2263521.0              182539.0    5316321
Germany                                    602635.0               31693.0    1105639
Japan                                      561157.0               27814.0     670636
Netherlands / United Kingdom               396556.0               23352.0      81000
Saudi Arabia                               355905.0              110974.0      76418
France                                     309684.0               13971.0     208525
United Kingdom                             303738.0                9383.0      73000
Russia                                     250447.0               33062.0     568600
South Korea                                221579.0               39895.0     221579
Switzerland                                219754.0                3408.0      85504
Singapore                                  180744.0                 849.0       4316
Taiwan                                     175617.0                4281.0     667680
Netherlands                                175009.0                1589.0     314790