JAX是一个由Google开发的用于优化科学计算Python库:
它可以被视为GPU和TPU上运行的NumPy,jax.numpy提供了与numpy非常相似API接口。
它与NumPyAPI非常相似,几乎任何可以用numpy完成的事情都可以用jax.numpy完成。
由于使用XLA(一种加速线性代数计算的编译器)将Python和JAX代码JIT编译成优化的内核,可以在不同设备(例如gpu和tpu)上运行。而优化的内核是为高吞吐量设备(例如gpu和tpu)进行编译,它与主程序分离但可以被主程序调用。JIT编译可以用jax.jit()触发。
它对自动微分有很好的支持,对机器学习研究很有用。可以使用jax.grad()触发自动区分。
JAX鼓励函数式编程,因为它的面向函数的。与NumPy数组不同,JAX数组始终是不可变的。
JAX提供了一些在编写数字处理时非常有用的程序转换,例如JIT.JAX()用于JIT编译和加速代码,JIT.grad()用于求导,以及JIT.vmap()用于自动向量化或批处理。
JAX可以进行异步调度。所以需要调用.block_until_ready()以确保计算已经实际发生。
JAX使用JIT编译有两种方式:
自动:在执行JAX函数的库调用时,默认情况下JIT编译会在后台进行。
手动:您可以使用jax.jit()手动请求对自己的Python函数进行JIT编译。
JAX使用示例
我们可以使用pip安装库。
pipinstalljax
导入需要的包,这里我们也继续使用NumPy,这样可以执行一些基准测试。
importjaximportjax.numpyasjnpfromjaximportrandomfromjaximportgrad,jitimportnumpyasnpkey=random.PRNGKey(0)
与importnumpyasnp类似,我们可以importjax.numpyasjnp并将代码中的所有np替换为jnp。如果NumPy代码是用函数式编程风格编写的,那么新的JAX代码就可以直接使用。但是,如果有可用的GPU,JAX则可以直接使用。
JAX中随机数的生成方式与NumPy不同。JAX需要创建一个jax.random.PRNGKey。我们稍后会看到如何使用它。
我们在GoogleColab上做一个简单的基准测试,这样我们就可以轻松访问GPU和TPU。我们首先初始化一个包含25M元素的随机矩阵,然后将其乘以它的转置。使用针对CPU优化的NumPy,矩阵乘法平均需要1.61秒。
#runsonCPU-numpysize=x=np.random.normal(size=(size,size)).astype(np.float32)%timeitnp.dot(x,x.T)#1loop,bestof5:1.61sperloop
在CPU上使用JAX执行相同的操作平均需要大约3.49秒。
#runsonCPU-JAXsize=x=random.normal(key,(size,size),dtype=jnp.float32)%timeitjnp.dot(x,x.T).block_until_ready()#1loop,bestof5:3.49sperloop
在CPU上运行时,JAX通常比NumPy慢,因为NumPy已针对CPU进行了非常多的优化。但是,当使用加速器时这种情况会发生变化,所以让我们尝试使用GPU进行矩阵乘法。
#runsonGPUsize=x=random.normal(key,(size,size),dtype=jnp.float32)%timex_jax=jax.device_put(x)#1.measureJAXdevicetransfertime%timejnp.dot(x_jax,x_jax.T).block_until_ready()#2.measureJAX