# 💡 Summary 和 TensorFlow / PyTorch 的直觉对比: - JAX:更 NumPy、更函数式,变换(`grad/jit/vmap/pjit`)是“第一公民”,在 TPU 上体验突出。 - PyTorch:动态图、生态与工程工具链更成熟(推理部署、ONNX、社区模型多)。 - TensorFlow:图/静态执行与企业部署链条强,但研究端近年更多人转向 JAX 或 PyTorch。 JAX 是 Google Research 推出的**高性能数值计算与自动微分库**。它有三件“看家本领”: 1. **NumPy 风格**:API 和 `numpy` 几乎一致(用的是 `jax.numpy`),上手成本低。 2. **可组合的程序变换**: - `jax.grad` 自动求导 - `jax.jit` 一键加速(用 XLA 编译到 CPU/GPU/TPU) - `jax.vmap` 自动向量化(把“对一个样本的函数”变成“对一批样本的函数”) - `jax.pmap` / `jax.pjit` 跨多卡/多机并行与分片 3. **纯函数/函数式风格**:强调不可变状态,更利于推理、并行与编译优化。 为什么研究和大模型爱用它: - **快**:`jit` + XLA 编译,算子融合、内存优化到位。 - **简**:把“写清楚数学函数”交给你,把“怎么并行、怎么编译”交给 JAX 的变换器。 - **大规模**:`pmap/pjit` + SPMD 分片,天然适合 TPU/GPU 集群训练。 - **生态**:常见搭子包括 - 模型框架:**Flax**、**Haiku**、**Equinox** - 优化器:**Optax** - 概率编程:**NumPyro** - 图与工具:**Jraph**、**Chex**、**Orbax** 等 (例如 AlphaFold、T5X/Pax 等研究系统都基于 JAX/其生态的变体。) # 🧩 Cues # 🪞Notes