谷歌狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理训练最快选择

AIGC动态8个月前发布 AIera
862 0 0
谷歌狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理训练最快选择

 

文章摘要


【关 键 词】 JAX性能测试谷歌PyTorchTensorFlow

新智元最近的报道关注了JAX这一谷歌力推的平台,它在最新的基准测试中表现出色,超越了PyTorchTensorFlow

这些测试结果表明,JAX在七项指标中排名第一,而且这些测试并未在JAX最擅长的TPU上进行,这意味着其性能可能更加出色。

尽管目前Pytorch在开发者中的受欢迎程度超过TensorFlow,但随着JAX的崛起,未来可能会有更多大型模型选择在JAX平台上进行训练和运行。

Keras团队对TensorFlow、JAX、PyTorch三个后端以及原生PyTorch实现和搭配TensorFlow的Keras 2进行了基准测试。

测试选取了一系列主流的计算机视觉和自然语言处理模型,包括来自HuggingFace Transformers的BERT、Gemma、Mistral,来自HuggingFace Diffusers的StableDiffusion,以及来自Meta的SegmentAnything。

测试使用合成数据,并在所有大型语言模型训练和推理中采用了bfloat16精度,同时在训练中使用了LoRA微调技术。

为了衡量性能,测试使用了高级API,并尽量减少了配置。

硬件配置方面,所有基准测试都在Google Cloud Compute Engine上进行,配置包括一块40GB显存的NVIDIA A100 GPU、12个虚拟CPU和85GB的主机内存。

测试结果显示,JAX在多个方面表现优异,但也发现不存在一个始终领先的后端,不同后端在不同模型架构下的表现各有千秋。

Keras 3的性能普遍超过了原生PyTorch的标准实现,在10个测试任务中,有5个的速度提升超过了50%,其中最高达到了290%。

关键发现包括:没有一个后端始终是最优的,选择哪个后端最快取决于模型架构;Keras 3的性能普遍超过了原生PyTorch;Keras 3提供了卓越的开箱即用性能,无需用户进行深入的性能优化;Keras 3的表现始终优于Keras 2,显示了显著的性能提升。

这些发现表明,JAX和Keras 3在性能上的提升可能会影响未来大型模型的选择和开发趋势,而谷歌在这方面的投入和努力已经开始得到了回报。

随着技术的不断进步,开发者和研究人员可能会更倾向于选择这些表现更佳的平台来构建和训练他们的模型。

原文和模型


【原文链接】 阅读原文 [ 1303字 | 6分钟 ]
【原文作者】 新智元
【摘要模型】 gpt-4
【摘要评分】 ★★☆☆☆

© 版权声明

相关文章

暂无评论

暂无评论...