Coding Example: The Mandelbrot Set (NumPy approach)
In this lesson, we are going to look at two NumPy approaches to solve this case study!
We'll cover the following...
Solution 1: NumPy Implementation
The trick is to find, at each iteration, the values that have not yet diverged and to update the relevant information for these values and only these values. Because we start from Z = 0, we know that each value will be updated at least once (when it is equal to 0, it has not yet diverged) and will stop being updated as soon as it has diverged. To do that, we'll use NumPy fancy indexing with the less(x1, x2) function, which returns the truth value of (x1 < x2) element-wise.
Press + to interact
def mandelbrot_numpy(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0):
    """Iterate z -> z**2 + c over a regular grid, updating only bounded points.

    Parameters
    ----------
    xmin, xmax : float
        Bounds of the real axis; sampled with ``xn`` points.
    ymin, ymax : float
        Bounds of the imaginary axis; sampled with ``yn`` points.
    xn, yn : int
        Number of samples along the real and imaginary axes.
    maxiter : int
        Number of iterations to run.
    horizon : float, optional
        A point is updated only while ``abs(z) < horizon``.

    Returns
    -------
    Z : ndarray of complex64, shape (yn, xn)
        Value of z for every point at the moment it stopped being updated.
    N : ndarray of int, shape (yn, xn)
        Index of the last iteration at which the point was still below the
        horizon. Points still below the horizon at the final iteration
        (``N == maxiter - 1``) are reset to 0.
    """
    X = np.linspace(xmin, xmax, xn, dtype=np.float32)
    Y = np.linspace(ymin, ymax, yn, dtype=np.float32)
    # Broadcasting a row of reals against a column of imaginaries builds
    # the full (yn, xn) grid of c values.
    C = X + Y[:, None]*1j
    N = np.zeros(C.shape, dtype=int)
    Z = np.zeros(C.shape, np.complex64)
    for n in range(maxiter):
        # Boolean mask of the points that have not crossed the horizon yet;
        # only those are touched below.
        I = np.less(abs(Z), horizon)
        N[I] = n
        Z[I] = Z[I]**2 + C[I]
    # Points that were still being updated on the last iteration never
    # escaped: mark them with 0.
    N[N == maxiter-1] = 0
    return Z, N
Now let's replace the Python approach with this one and see what happens:
Press + to interact
main.py
tools.py
# -----------------------------------------------------------------------------
# From Numpy to Python
# Copyright (2017) Nicolas P. Rougier - BSD license
# More information at https://github.com/rougier/numpy-book
# -----------------------------------------------------------------------------
import os

import numpy as np


def mandelbrot_numpy1(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0):
    """Iterate z -> z**2 + c over a regular grid, updating only bounded points.

    The grid has ``xn`` samples of the real axis in [xmin, xmax] and ``yn``
    samples of the imaginary axis in [ymin, ymax].

    Returns
    -------
    Z : ndarray of complex64, shape (yn, xn)
        Value of z for every point at the moment it stopped being updated.
    N : ndarray of int, shape (yn, xn)
        Index of the last iteration at which the point was still below
        ``horizon``; points still below it at the final iteration
        (``N == maxiter - 1``) are reset to 0.
    """
    # Adapted from https://www.ibm.com/developerworks/community/blogs/jfp/...
    #              .../entry/How_To_Compute_Mandelbrodt_Set_Quickly?lang=en
    X = np.linspace(xmin, xmax, xn, dtype=np.float32)
    Y = np.linspace(ymin, ymax, yn, dtype=np.float32)
    # Row of reals + column of imaginaries -> (yn, xn) grid of c values.
    C = X + Y[:, None]*1j
    N = np.zeros(C.shape, dtype=int)
    Z = np.zeros(C.shape, np.complex64)
    for n in range(maxiter):
        # Only points that have not crossed the horizon are updated.
        I = np.less(abs(Z), horizon)
        N[I] = n
        Z[I] = Z[I]**2 + C[I]
    # Points still being updated on the last iteration never escaped.
    N[N == maxiter-1] = 0
    return Z, N


def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, itermax, horizon=2.0):
    """Iterate z -> z**2 + c, dropping points from the work set once they escape.

    The working arrays are flattened and shrink at every iteration so that
    only points that have not yet exceeded ``horizon`` are computed.

    Returns
    -------
    Z : ndarray of complex64, shape (yn, xn)
        Value of z at the iteration where ``abs(z) > horizon`` first held;
        0 for points that never escaped.
    N : ndarray of uint32, shape (yn, xn)
        1-based iteration count at which the point escaped; 0 for points
        that never escaped within ``itermax`` iterations.
    """
    # Adapted from
    # https://thesamovar.wordpress.com/2009/03/22/fast-fractals-with-python-and-numpy/
    Xi, Yi = np.mgrid[0:xn, 0:yn]
    Xi, Yi = Xi.astype(np.uint32), Yi.astype(np.uint32)
    X = np.linspace(xmin, xmax, xn, dtype=np.float32)[Xi]
    Y = np.linspace(ymin, ymax, yn, dtype=np.float32)[Yi]
    C = X + Y*1j
    N_ = np.zeros(C.shape, dtype=np.uint32)
    Z_ = np.zeros(C.shape, dtype=np.complex64)
    # Flatten the index and c arrays; Xi/Yi remember where each remaining
    # point lives in the 2-D output arrays.
    Xi, Yi, C = Xi.ravel(), Yi.ravel(), C.ravel()

    Z = np.zeros(C.shape, np.complex64)
    for i in range(itermax):
        # Every point has escaped: nothing left to compute.
        if not Z.size:
            break

        # Compute for relevant points only (in place: Z = Z*Z + C)
        np.multiply(Z, Z, Z)
        np.add(Z, C, Z)

        # Failed convergence
        I = abs(Z) > horizon
        N_[Xi[I], Yi[I]] = i+1
        Z_[Xi[I], Yi[I]] = Z[I]

        # Keep going with those who have not diverged yet
        np.logical_not(I, I)
        Z = Z[I]
        Xi, Yi = Xi[I], Yi[I]
        C = C[I]
    # Outputs were filled as (xn, yn); transpose to (yn, xn).
    return Z_.T, N_.T


if __name__ == '__main__':
    from matplotlib import colors
    import matplotlib.pyplot as plt
    from tools import timeit

    # Benchmark
    xmin, xmax, xn = -2.25, +0.75, int(3000/3)
    ymin, ymax, yn = -1.25, +1.25, int(2500/3)
    maxiter = 200
    timeit("mandelbrot_numpy1(xmin, xmax, ymin, ymax, xn, yn, maxiter)",
           globals())

    # Visualization
    xmin, xmax, xn = -2.25, +0.75, int(3000/2)
    ymin, ymax, yn = -1.25, +1.25, int(2500/2)
    maxiter = 20
    horizon = 2.0 ** 40
    log_horizon = np.log(np.log(horizon))/np.log(2)
    Z, N = mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon)

    # Normalized recount as explained in:
    # http://linas.org/art-gallery/escape/smooth.html
    # Points that never escaped have Z == 0, so the nested logs produce
    # -inf/NaN there; silence the expected warnings and let nan_to_num
    # clean the values up.
    with np.errstate(divide='ignore', invalid='ignore'):
        M = np.nan_to_num(N + 1 - np.log(np.log(abs(Z)))/np.log(2)
                          + log_horizon)

    dpi = 72
    width = 10
    height = 10*yn/xn

    fig = plt.figure(figsize=(width, height), dpi=dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frameon=False, aspect=1)

    light = colors.LightSource(azdeg=315, altdeg=10)

    plt.imshow(light.shade(M, cmap=plt.cm.hot, vert_exag=1.5,
                           norm=colors.PowerNorm(0.3), blend_mode='hsv'),
               extent=[xmin, xmax, ymin, ymax], interpolation="bicubic")
    ax.set_xticks([])
    ax.set_yticks([])
    # savefig does not create missing directories.
    os.makedirs("output", exist_ok=True)
    plt.savefig("output/mandelbrot.png")
    plt.show()
Here is the benchmark:
timeit("mandelbrot_python(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals())
#1 loops, best of 3: 6.1 sec per loop
timeit("mandelbrot_numpy1(xmin, xmax, ymin, ymax, xn, yn, maxiter)",
...