python 作为一种解释型语言,通常运行速度较慢,特别是在科学计算领域。为了提高运行效率,通常需要使用一些加速方法。本文将简单介绍如何使用这些方法加速 Python 科学计算。

首先,我们定义一个 Mandelbrot 函数,用于演示加速方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import numpy as np
import time

def MandNumba(ext, max_steps, Nx, Ny):
data = np.ones((Nx, Ny)) * max_steps
for i in range(Nx):
for j in range(Ny):
x = ext[0] + (ext[1] - ext[0]) * i / (Nx - 1.)
y = ext[2] + (ext[3] - ext[2]) * j / (Ny - 1.)
z0 = x + y * 1j
z = 0j
for itr in range(max_steps):
if abs(z) > 2.:
data[j, i] = itr
break
z = z * z + z0
return data
1
2
3
4
5
6
7
8
9
10
Nx = 1000
Ny = 1000
max_steps = 1000 # 50

ext = [-2, 1, -1, 1]

t0 = time.time()
data = MandNumba(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1 - t0)
clock time:  29.383343935012817

可以看到不进行任何加速需要 29.38s ,非常慢。

以下是几种加速该函数的Python科学计算方法:

  1. Numba JIT 并行加速
  2. 多进程并行
  3. NumPy 向量化

还有其他加速方法,如 CythonNumba CUDA 等,这些方法需要GPU硬件支持,这里不做介绍。

1. Numba JIT 并行加速

Numba 是一个开源的 JIT 编译器,可以将 Python 代码编译成机器码,从而提高运行速度。Numba 支持并行计算,可以通过 @numba.njit(parallel=True) 修饰器实现并行计算。

这里解释一下什么是修饰器,在 Python 中,修饰器(Decorator)是一种设计模式,用于在不修改原始函数代码的情况下,动态地给函数添加新的功能。修饰器本质上是一个函数,它接收一个函数作为参数,并返回一个新的函数。

在这里,@numba.njit(parallel=True) 修饰器将 MandNumba_parallel 函数编译成机器码,并实现并行计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import numpy as np
import numba
import time

@numba.njit(parallel=True)
def MandNumba_parallel(ext, max_steps, Nx, Ny):
data = np.ones((Ny, Nx), dtype=np.int32) * max_steps
for i in numba.prange(Nx): # 并行外层循环
for j in range(Ny):
x = ext[0] + (ext[1] - ext[0]) * i / (Nx - 1)
y = ext[2] + (ext[3] - ext[2]) * j / (Ny - 1)
z0 = x + y * 1j
z = 0j
for itr in range(max_steps):
if abs(z) > 2:
data[j, i] = itr
break
z = z * z + z0
return data
1
2
3
4
5
6
7
8
9
10
Nx = 1000
Ny = 1000
max_steps = 1000 # 50

ext = [-2, 1, -1, 1]

t0 = time.time()
data = MandNumba_parallel(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1 - t0)
clock time:  0.8841660022735596

这里只花了 0.88s ,加速了约 33 倍。

2. 多进程并行

多进程并行是一种利用多核 CPU 并行计算的方法,适合 CPU 密集型任务。

就是让每个人算一行,然后把结果合并起来。这是最朴素的并行计算方法,也是最容易实现的方法。

Python 提供了 multiprocessing 模块,可以很方便地实现多进程并行计算。

由于 multiprocessing 模块的限制,被并行计算的函数不能是类的成员函数,也不能是全局函数。因此,我们需要将 compute_row 函数单独放在一个文件中,然后通过 multiprocessing 模块调用。

1
!ls
__pycache__           compute_one_row.py    numpyAccelerate.ipynb
1
!cat compute_one_row.py
import numpy as np

def compute_row(ext, max_steps, Nx, Ny, row):
    result = np.empty(Ny, dtype=np.int64)
    for j in range(Ny):
        x = ext[0] + (ext[1] - ext[0]) * row / (Nx - 1.)
        y = ext[2] + (ext[3] - ext[2]) * j / (Ny - 1.)
        z0 = x + y * 1j
        z = 0j
        for itr in range(max_steps):
            if abs(z) > 2.:
                result[j] = itr
                break
            z = z * z + z0
        else:
            result[j] = max_steps
    return result
1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
import multiprocessing as mp
from compute_one_row import compute_row
import time

def MandelMultiProcess(ext, max_steps, Nx, Ny):
data = np.ones((Nx, Ny), dtype=np.int64) * max_steps
with mp.Pool(processes=mp.cpu_count()) as pool:
results = [pool.apply_async(compute_row, (ext, max_steps, Nx, Ny, i)) for i in range(Nx)]
for i in range(Nx):
data[i, :] = results[i].get()
return data

1
2
3
4
5
6
7
8
9
10
Nx = 1000
Ny = 1000
max_steps = 1000 # 50

ext = [-2, 1, -1, 1]

t0 = time.time()
data = MandelMultiProcess(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1 - t0)
clock time:  3.7328829765319824

我使用的是 Macbook Pro M2 Pro 版本,有6个性能核,4个能效核,总共10个核。所耗费的时间是 3.73s ,加速了约 8 倍。

4. NumPy 向量化

NumPy 是 Python 科学计算的基础库,提供了很多高效的数学函数和运算符。NumPy 的向量化操作可以大大提高运算速度。

下面的代码是将 MandNumba 函数向量化,使用 NumPy 的数组运算代替循环运算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import numpy as np


def MandNumba_vectorized(ext, max_steps, Nx, Ny):
x = np.linspace(ext[0], ext[1], Nx)
y = np.linspace(ext[2], ext[3], Ny)
X, Y = np.meshgrid(x, y)
Z0 = X + Y * 1j
Z = np.zeros_like(Z0)
data = np.full(Z0.shape, max_steps, dtype=int)

mask = np.ones_like(Z0, dtype=bool)
for itr in range(max_steps):
Z[mask] = Z[mask] ** 2 + Z0[mask]
escaped = (np.abs(Z) > 2) & mask
data[escaped] = itr
mask[escaped] = False
if not mask.any():
break
return data
1
2
3
4
5
6
7
8
9
Nx = 1000
Ny = 1000
max_steps = 1000 # 50

ext = [-2, 1, -1, 1]
t0 = time.time()
data = MandNumba_vectorized(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1 - t0)
clock time:  5.844282150268555

这里花费了 5.84s ,加速了约 5 倍。

性能对比说明:

  1. Numba JIT:通过即时编译优化循环,通常可获得10-100倍加速
  2. 多进程:适合CPU密集型任务,但进程间通信可能成为瓶颈
  3. 向量化:对小规模计算友好,但内存消耗随问题规模平方增长
    可以看到,Numba JIT 是最快的,其本质是将 Python 代码编译成机器码,作弊变成C++fortran,消除了 Python 解释器的性能瓶颈。

最后感谢 DeepseekGithub Copilot 的帮助。

并行后的 mandf-dynamic.py 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#%matplotlib notebook
import numpy as np
import pylab as plt
import time
import numba

@numba.njit(parallel=True)
def MandNumba(ext, max_steps, Nx, Ny):
data = np.ones((Nx, Ny), dtype=np.int32) * max_steps
for i in range(Nx):
for j in range(Ny):
x = ext[0] + (ext[1] - ext[0]) * i / (Nx - 1.)
y = ext[2] + (ext[3] - ext[2]) * j / (Ny - 1.)
z0 = x + y * 1j
z = 0j
for itr in range(max_steps):
if abs(z) > 2.:
data[j, i] = itr
break
z = z * z + z0
return data

def ax_update(ax): # actual plotting routine
ax.set_autoscale_on(False) # Otherwise, infinite loop
# Get the range for the new area
xstart, ystart, xdelta, ydelta = ax.viewLim.bounds
xend = xstart + xdelta
yend = ystart + ydelta
ext=np.array([xstart,xend,ystart,yend])
data = MandNumba(ext, max_steps, Nx, Ny) # actually producing new fractal

# Update the image object with our new data and extent
im = ax.images[-1] # take the latest object
im.set_data(data) # update it with new data
im.set_extent(ext) # change the extent
ax.figure.canvas.draw_idle() # finally redraw

if __name__ == '__main__':
Nx = 1000
Ny = 1000
max_steps = 1000 # 50

ext = [-2,1,-1,1]

t0 = time.time()
data = MandNumba(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1-t0)

fig,ax=plt.subplots(1,1)
ax.imshow(data, extent=ext,aspect='equal',origin='lower',cmap='plasma')

ax.callbacks.connect('xlim_changed', ax_update)
ax.callbacks.connect('ylim_changed', ax_update)
plt.show()

clock time:  1.4002361297607422