前言

Python 因其简洁优雅的语法和丰富的生态系统而广受欢迎,但在计算密集型任务中,Python 的执行速度往往成为瓶颈。虽然我们可以使用 C/C++ 扩展或者 Cython 来提升性能,但这些方案的学习成本和开发复杂度都比较高。

Numba 的出现改变了这一切!它是一个针对 Python 的即时(JIT)编译器,能够将 Python 代码直接编译为机器码,实现接近 C 语言的执行速度。最关键的是,你几乎不需要修改现有的 Python 代码,只需要添加一个装饰器就能获得 10-100 倍的性能提升!

本文将全面介绍 Numba 的使用方法,从基础概念到高级技巧,帮助你掌握这个强大的性能优化工具。

1. Numba 简介和安装

1.1 什么是 Numba?

Numba 是一个开源的 JIT 编译器,它使用 LLVM 编译器库将 Python 函数编译为优化的机器码。Numba 专门针对 NumPy 数组和数值计算进行了优化,能够显著提升数值计算代码的性能。

Numba 的主要特点:

  • 易于使用:只需添加装饰器,无需重写代码
  • 高性能:能够实现接近 C 语言的执行速度
  • 兼容性好:支持大部分 NumPy 功能和 Python 语法
  • 自动优化:自动进行循环优化、向量化等
  • 并行支持:支持 CPU 和 GPU 并行计算

1.2 安装 Numba

使用 pip 安装 Numba:

# 注意如果要使用numba 建议使用 python3.9或3.10
pip install numba
# 下面的版本实测不会产生依赖冲突
# numba==0.56.4 
# numpy==1.23.5
# llvmlite==0.39.1

1.3 第一个 Numba 程序

让我们从一个简单的例子开始,体验 Numba 的威力:

import numpy as np
import time
from numba import jit

def python_sum(arr):
    """传统 Python 实现"""
    total = 0.0
    for i in range(len(arr)):
        total += arr[i]
    return total

@jit
def numba_sum(arr):
    """Numba 优化版本"""
    total = 0.0
    for i in range(len(arr)):
        total += arr[i]
    return total

# 测试性能
if __name__ == "__main__":
    # 生成测试数据
    data = np.random.random(1000000)
    
    # 预热 Numba 函数
    _ = numba_sum(data[:100])
    
    # 测试 Python 版本
    start_time = time.time()
    result_python = python_sum(data)
    python_time = time.time() - start_time
    
    # 测试 Numba 版本
    start_time = time.time()
    result_numba = numba_sum(data)
    numba_time = time.time() - start_time
    
    print(f"Python 版本耗时: {python_time:.4f}秒")
    print(f"Numba 版本耗时:  {numba_time:.4f}秒")
    print(f"性能提升: {python_time/(numba_time if numba_time > 0 else 0.0001):.1f}倍")
    print(f"结果一致: {abs(result_python - result_numba) < 1e-10}")

运行这个例子,你会发现 Numba 版本比 Python 版本快了几十甚至上百倍!这就是 Numba 的魅力所在。
在这里插入图片描述

2. Numba 基础语法和装饰器

2.1 @jit 装饰器

@jit 是 Numba 最基本的装饰器,它会在函数首次调用时将 Python 代码编译为机器码。

from numba import jit
import numpy as np

@jit
def calculate_distance(x1, y1, x2, y2):
    """计算两点距离"""
    return np.sqrt((x2 - x1)**2 + (y2 - y1)**2)

# 使用示例
distance = calculate_distance(0, 0, 3, 4)
print(f"距离: {distance}")  # 输出: 5.0

2.2 @njit 装饰器

@njit@jit(nopython=True) 的简写,它强制 Numba 完全脱离 Python 解释器运行,通常能获得更好的性能:

import numba as nb
import numpy as np
import time

def pure_python_sum(arr: np.ndarray) -> float:
    """纯Python版本的数组求和"""
    total = 0.0
    for i in range(len(arr)):
        total += arr[i]
    return total

@nb.njit
def numba_sum(arr: np.ndarray) -> float:
    """Numba加速版本的数组求和"""
    total = 0.0
    for i in range(len(arr)):
        total += arr[i]
    return total

test_array = np.random.random(10000000) * 100

start_time = time.time()
result_python = pure_python_sum(test_array)
python_time = time.time() - start_time

start_time = time.time()
result_numba = numba_sum(test_array)
numba_time = time.time() - start_time

start_time = time.time()
result_numba2 = numba_sum(test_array)
numba_time2 = time.time() - start_time

print(f"纯Python耗时: {python_time:.6f}秒")
print(f"Numba首次调用(含编译): {numba_time:.6f}秒")
print(f"Numba第二次调用: {numba_time2:.6f}秒")
print(f"性能提升: {python_time / (numba_time2 if numba_time2 > 0 else 0.0001):.1f}倍")
print(f"结果一致性: {np.isclose(result_python, result_numba)}")
print(f"结果一致性: {np.isclose(result_python, result_numba2)}")

在这里插入图片描述

2.3 编译模式和选项

Numba 提供多种编译模式和选项来控制编译行为:

from numba import njit, prange
import numpy as np

# 缓存编译结果,避免重复编译
@njit(cache=True)
def cached_function(x):
    return x * x + 2 * x + 1
 
# 指定函数签名,提前编译
@njit("float64[:](float64[:])")
def typed_function(x):
    return np.sin(x) + np.cos(x)

# 启用并行计算(注意:使用numba并行计算, 循环需要使用prange, 原始的range不支持并行)
@njit(parallel=True)
def parallel_function(arr):
    result = np.zeros_like(arr)
    for i in prange(len(arr)):
        result[i] = arr[i] ** 2 + arr[i] ** 0.5
    return result

# 错误处理模式
@njit(error_model='numpy')
def error_safe_function(x):
    return np.sqrt(x)  # 对负数返回 NaN 而不是抛出异常

# 使用示例
x = np.linspace(-10, 10, 1000)
y1 = cached_function(x)
y2 = typed_function(x)
y3 = parallel_function(np.abs(x))
y4 = error_safe_function(x)  # 包含负数,会产生 NaN

print(f"缓存函数结果: {y1[:5]}")
print(f"类型化函数结果: {y2[:5]}")
print(f"并行函数结果: {y3[:5]}")
print(f"错误安全函数结果: {y4[:5]}")

3. Numba 支持的数据类型和操作

3.1 支持的数据类型

Numba 支持大部分 NumPy 数据类型和 Python 基本类型:

from numba import njit, types
import numpy as np

@njit
def data_types_demo():
    """演示 Numba 支持的数据类型"""
    # 基本数值类型
    int_val = 42
    float_val = 3.14
    complex_val = 1.0 + 2.0j
    bool_val = True
    
    # NumPy 数组
    int_array = np.array([1, 2, 3], dtype=np.int32)
    float_array = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    bool_array = np.array([True, False, True])
    
    # 多维数组
    matrix = np.zeros((3, 3), dtype=np.float32)
    matrix[0, 0] = 1.0
    matrix[1, 1] = 1.0
    matrix[2, 2] = 1.0

    trace_val = 0.0
    for i in range(matrix.shape[0]):
        trace_val += matrix[i, i]
    
    # 元组
    coordinates = (1.0, 2.0, 3.0)
    
    return (int_val, float_val, complex_val, bool_val, 
            int_array.sum(), float_array.mean(),  bool_array.sum(),
            trace_val, coordinates[0])

# 显式类型声明
@njit("Tuple((int64, float64, complex128))(float64[:])")
def explicit_types(arr):
    """使用显式类型声明"""
    total = arr.sum()
    mean = arr.mean()
    complex_result = total + 1j * mean
    return int(total), mean, complex_result

# 测试
result = data_types_demo()
print(f"数据类型演示结果: {result}")

arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
typed_result = explicit_types(arr)
print(f"显式类型结果: {typed_result}")

3.2 NumPy 函数支持

Numba 支持大量的 NumPy 函数和操作:

from numba import njit
import numpy as np

@njit
def numpy_functions_demo(x, y):
    """演示 Numba 支持的 NumPy 函数"""
    # 数学函数
    sin_x = np.sin(x)
    cos_x = np.cos(x)
    exp_x = np.exp(x)
    log_x = np.log(np.abs(x) + 1e-10)  # 避免 log(0)
    
    # 统计函数
    mean_val = np.mean(x)
    std_val = np.std(x)
    min_val = np.min(x)
    max_val = np.max(x)
    
    # 数组操作
    sorted_x = np.sort(x)
    unique_x = np.unique(x.astype(np.int32))
    
    # 线性代数(部分支持)
    dot_product = np.dot(x, y)
    
    # 逻辑操作
    mask = x > 0
    positive_x = x[mask]
    
    # 数组创建
    zeros = np.zeros(10)
    ones = np.ones(5)
    arange = np.arange(0, 10, 2)
    
    return {
        'sin_mean': np.mean(sin_x),
        'cos_std': np.std(cos_x),
        'exp_max': np.max(exp_x),
        'log_min': np.min(log_x),
        'stats': (mean_val, std_val, min_val, max_val),
        'sorted_first': sorted_x[0],
        'unique_count': len(unique_x),
        'dot_product': dot_product,
        'positive_count': len(positive_x),
        'created_arrays': (len(zeros), len(ones), len(arange))
    }

# 注意:Numba 不支持返回字典,这里仅作演示
# 实际使用时应该返回元组或数组

@njit
def numpy_functions_practical(x, y):
    """实用的 NumPy 函数演示"""
    # 数学运算
    result1 = np.sqrt(x**2 + y**2)
    result2 = np.arctan2(y, x)
    
    # 统计分析
    correlation = np.corrcoef(x, y)[0, 1]
    
    # 数组处理
    combined = np.concatenate((x, y))
    reshaped = combined.reshape(-1, 1)
    
    return result1.mean(), result2.std(), correlation, reshaped.shape[0]

# 测试
x = np.random.randn(1000)
y = np.random.randn(1000)

mean_dist, std_angle, corr, total_len = numpy_functions_practical(x, y)
print(f"平均距离: {mean_dist:.4f}")
print(f"角度标准差: {std_angle:.4f}")
print(f"相关系数: {corr:.4f}")
print(f"总长度: {total_len}")

4. 控制流和循环优化

4.1 循环优化

Numba 对循环进行了特殊优化,能够自动向量化和并行化循环:

from numba import njit, prange
import numpy as np
import time

@njit
def sequential_loop(arr):
    """顺序循环"""
    result = np.zeros_like(arr)
    for i in range(len(arr)):
        result[i] = arr[i] ** 2 + np.sin(arr[i]) + np.cos(arr[i])
    return result

@njit(parallel=True)
def parallel_loop(arr):
    """并行循环"""
    result = np.zeros_like(arr)
    for i in prange(len(arr)):  # 使用 prange 启用并行
        result[i] = arr[i] ** 2 + np.sin(arr[i]) + np.cos(arr[i])
    return result

@njit
def nested_loops_optimization(matrix):
    """嵌套循环优化"""
    rows, cols = matrix.shape
    result = np.zeros_like(matrix)
    
    # Numba 会自动优化这种嵌套循环
    for i in range(rows):
        for j in range(cols):
            # 复杂的计算
            val = matrix[i, j]
            result[i, j] = val**3 - 2*val**2 + val + 1
    
    return result

@njit
def loop_with_conditions(arr, threshold):
    """带条件的循环优化"""
    count = 0
    total = 0.0
    
    for i in range(len(arr)):
        if arr[i] > threshold:
            total += arr[i]
            count += 1
        elif arr[i] < -threshold:
            total -= arr[i]
            count += 1
    
    return total / count if count > 0 else 0.0

# 性能测试
def test_loop_performance():
    """测试循环性能"""
    data = np.random.randn(1000000)
    matrix = np.random.randn(1000, 1000)
    
    # 预热
    _ = sequential_loop(data[:100])
    _ = parallel_loop(data[:100])
    
    # 顺序循环测试
    start_time = time.time()
    result1 = sequential_loop(data)
    seq_time = time.time() - start_time
    
    # 并行循环测试
    start_time = time.time()
    result2 = parallel_loop(data)
    par_time = time.time() - start_time
    
    # 嵌套循环测试
    start_time = time.time()
    result3 = nested_loops_optimization(matrix)
    nested_time = time.time() - start_time
    
    # 条件循环测试
    start_time = time.time()
    result4 = loop_with_conditions(data, 0.5)
    cond_time = time.time() - start_time
    
    print(f"顺序循环耗时: {seq_time:.4f}秒")
    print(f"并行循环耗时: {par_time:.4f}秒")
    print(f"并行加速比: {seq_time/par_time:.2f}x")
    print(f"嵌套循环耗时: {nested_time:.4f}秒")
    print(f"条件循环耗时: {cond_time:.4f}秒")
    print(f"条件循环结果: {result4:.4f}")
    
    # 验证结果一致性
    print(f"结果一致性: {np.allclose(result1, result2)}")

# 运行性能测试
test_loop_performance()

4.2 条件语句优化

Numba 能够高效处理条件语句和分支预测:

from numba import njit
import numpy as np

@njit
def conditional_optimization(x, y):
    """条件语句优化示例"""
    result = np.zeros_like(x)
    
    for i in range(len(x)):
        if x[i] > 0 and y[i] > 0:
            # 第一象限
            result[i] = x[i] + y[i]
        elif x[i] < 0 and y[i] > 0:
            # 第二象限
            result[i] = -x[i] + y[i]
        elif x[i] < 0 and y[i] < 0:
            # 第三象限
            result[i] = -x[i] - y[i]
        else:
            # 第四象限
            result[i] = x[i] - y[i]
    
    return result

@njit
def vectorized_conditions(x, y):
    """向量化条件处理"""
    # 使用 NumPy 的 where 函数进行向量化条件处理
    quad1 = (x > 0) & (y > 0)
    quad2 = (x < 0) & (y > 0)
    quad3 = (x < 0) & (y < 0)
    quad4 = ~(quad1 | quad2 | quad3)
    
    result = np.zeros_like(x)
    result = np.where(quad1, x + y, result)
    result = np.where(quad2, -x + y, result)
    result = np.where(quad3, -x - y, result)
    result = np.where(quad4, x - y, result)
    
    return result

@njit
def complex_branching(data, mode):
    """复杂分支逻辑"""
    n = len(data)
    result = np.zeros(n)
    
    for i in range(n):
        val = data[i]
        
        if mode == 1:
            if val > 0:
                result[i] = np.sqrt(val)
            else:
                result[i] = 0
        elif mode == 2:
            if val > 1:
                result[i] = np.log(val)
            elif val > 0:
                result[i] = val
            else:
                result[i] = -val
        else:
            result[i] = np.abs(val)
    
    return result

# 测试条件语句优化
x = np.random.randn(100000)
y = np.random.randn(100000)

result1 = conditional_optimization(x, y)
result2 = vectorized_conditions(x, y)

print(f"条件优化结果一致性: {np.allclose(result1, result2)}")

# 测试复杂分支
data = np.random.randn(10000)
for mode in [1, 2, 3]:
    result = complex_branching(data, mode)
    print(f"模式 {mode} 处理完成,平均值: {np.mean(result):.4f}")

5. 并行计算和prange

5.1 基础并行计算

Numba 提供了简单易用的并行计算功能,通过 prange 可以轻松实现多线程并行:

from numba import njit, prange
import numpy as np
import time

@njit
def serial_computation(data):
    """串行计算"""
    result = np.zeros_like(data)
    for i in range(len(data)):
        # 模拟复杂计算
        temp = data[i]
        for j in range(100):
            temp = temp * 0.99 + 0.01 * np.sin(temp)
        result[i] = temp
    return result

@njit(parallel=True)
def parallel_computation(data):
    """并行计算"""
    result = np.zeros_like(data)
    for i in prange(len(data)):  # 使用 prange 实现并行
        # 相同的复杂计算
        temp = data[i]
        for j in range(100):
            temp = temp * 0.99 + 0.01 * np.sin(temp)
        result[i] = temp
    return result

@njit(parallel=True)
def parallel_matrix_operations(matrix):
    """并行矩阵操作"""
    rows, cols = matrix.shape
    result = np.zeros_like(matrix)
    
    # 并行处理每一行
    for i in prange(rows):
        for j in range(cols):
            # 对每个元素进行复杂变换
            val = matrix[i, j]
            result[i, j] = np.exp(-val**2) * np.cos(val) + np.sin(val)
    
    return result

@njit(parallel=True)
def parallel_reduction(data):
    """并行归约操作"""
    n = len(data)
    # 计算平方和
    sum_squares = 0.0
    for i in prange(n):
        sum_squares += data[i] ** 2
    
    return sum_squares

# 性能对比测试
def benchmark_parallel():
    """并行计算性能基准测试"""
    data = np.random.randn(10000)
    matrix = np.random.randn(500, 500)
    
    # 预热函数
    _ = serial_computation(data[:100])
    _ = parallel_computation(data[:100])
    
    print("=" * 50)
    print("并行计算性能测试")
    print("=" * 50)
    
    # 测试一维数组处理
    start = time.time()
    result_serial = serial_computation(data)
    serial_time = time.time() - start
    
    start = time.time()
    result_parallel = parallel_computation(data)
    parallel_time = time.time() - start
    
    print(f"一维数组处理:")
    print(f"  串行耗时: {serial_time:.4f}秒")
    print(f"  并行耗时: {parallel_time:.4f}秒")
    print(f"  加速比: {serial_time/parallel_time:.2f}x")
    print(f"  结果一致: {np.allclose(result_serial, result_parallel)}")
    
    # 测试矩阵操作
    start = time.time()
    matrix_result = parallel_matrix_operations(matrix)
    matrix_time = time.time() - start
    
    print(f"\n矩阵操作:")
    print(f"  并行耗时: {matrix_time:.4f}秒")
    print(f"  处理速度: {matrix.size/matrix_time/1000:.1f}K 元素/秒")
    
    # 测试归约操作
    start = time.time()
    sum_result = parallel_reduction(data)
    reduction_time = time.time() - start
    
    # 验证结果
    expected_sum = np.sum(data**2)
    print(f"\n归约操作:")
    print(f"  并行耗时: {reduction_time:.6f}秒")
    print(f"  结果验证: {abs(sum_result - expected_sum) < 1e-10}")

# 运行基准测试
benchmark_parallel()

6. JitClass - 类的编译优化

6.1 JitClass 基础

Numba 允许使用 @jitclass 装饰器来编译类,实现高性能的面向对象编程:

from numba import jitclass, njit, types
import numpy as np

# 定义类的数据结构
spec = [
    ('value', types.float64),
    ('data', types.float64[:]),
    ('size', types.int64)
]

@jitclass(spec)
class FastArray:
    """高性能数组类"""
    
    def __init__(self, size):
        self.size = size
        self.data = np.zeros(size)
        self.value = 0.0
    
    def set_value(self, val):
        """设置标量值"""
        self.value = val
    
    def fill(self, val):
        """填充数组"""
        for i in range(self.size):
            self.data[i] = val
    
    def add_scalar(self, val):
        """数组加标量"""
        for i in range(self.size):
            self.data[i] += val
    
    def multiply_scalar(self, val):
        """数组乘标量"""
        for i in range(self.size):
            self.data[i] *= val
    
    def dot_product(self, other):
        """计算与另一个 FastArray 的点积"""
        if self.size != other.size:
            return -1.0  # 错误标志
        
        result = 0.0
        for i in range(self.size):
            result += self.data[i] * other.data[i]
        return result
    
    def norm(self):
        """计算向量的模长"""
        sum_squares = 0.0
        for i in range(self.size):
            sum_squares += self.data[i] * self.data[i]
        return np.sqrt(sum_squares)
    
    def normalize(self):
        """归一化向量"""
        norm_val = self.norm()
        if norm_val > 1e-10:
            for i in range(self.size):
                self.data[i] /= norm_val

# 使用 JitClass 的函数
@njit
def vector_operations(size):
    """演示 JitClass 的使用"""
    # 创建两个向量
    vec1 = FastArray(size)
    vec2 = FastArray(size)
    
    # 初始化向量
    vec1.fill(1.0)
    vec2.fill(2.0)
    
    # 执行操作
    vec1.add_scalar(0.5)  # vec1 现在是 [1.5, 1.5, ...]
    vec2.multiply_scalar(1.5)  # vec2 现在是 [3.0, 3.0, ...]
    
    # 计算点积
    dot = vec1.dot_product(vec2)
    
    # 计算模长
    norm1 = vec1.norm()
    norm2 = vec2.norm()
    
    # 归一化
    vec1.normalize()
    vec2.normalize()
    
    # 归一化后的模长
    norm1_after = vec1.norm()
    norm2_after = vec2.norm()
    
    return dot, norm1, norm2, norm1_after, norm2_after

# 测试 JitClass
size = 10000
dot, norm1, norm2, norm1_after, norm2_after = vector_operations(size)

print(f"向量尺寸: {size}")
print(f"点积: {dot}")
print(f"归一化前模长: vec1={norm1:.6f}, vec2={norm2:.6f}")
print(f"归一化后模长: vec1={norm1_after:.6f}, vec2={norm2_after:.6f}")

结语

Numba 是一个强大的 Python 性能优化工具,它让我们能够在保持 Python 简洁性的同时获得接近 C 语言的执行速度。经历大量尝试之后总结下面的“教训”:

  1. 优先考虑使用@nb.njit装饰器进行JIT编译(或者使用@nb.jit(nopython=True)
  2. 尽量使用原生数据类型(int, float等)
  3. 避免使用Python容器(list/dict),优先使用NumPy数组
  4. 循环密集型计算性能提升最明显, 将复杂算法分解为多个简单函数
  5. 不要在Numba函数中使用print()调试, 建议先用纯Python版本验证逻辑
  6. 适合数值计算,不适合字符串/对象操作
  7. 使用@jitclass必须显式声明所有属性的类型, 构造函数参数类型必须明确
  8. jitclass类方法中不能调用不支持的Python函数

Numba 让 Python 在数值计算领域真正实现了"原地起飞",希望本文能帮助你更好地掌握这个强大的工具,在你的项目中发挥其威力!


相关链接:

示例代码仓库:
本文所有示例代码都经过测试验证,你可以直接复制运行。建议在博文直接运行查看结果,或者也可以复制代码到本地逐个测试这些示例,以加深理解。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐