关于ZAKER 融媒体解决方案 合作 加入

DeepMind 发布神经网络、强化学习库,网友:推动 JAX 发展

量子位 02-21

DeepMind 发布神经网络、强化学习库,网友:推动 JAX 发展

DeepMind 今日发布了HaikuRLax两个库,都是基于 JAX。

JAX 由谷歌提出,是 TensorFlow 的简化库。结合了针对线性代数的编译器 XLA,和自动区分本地 Python 和 Numpy 代码的库 Autograd,在高性能的机器学习研究中使用。

而此次发布的两个库,分别针对神经网络强化学习,大幅简化了 JAX 的使用。

Haiku 是基于 JAX 的神经网络库,允许用户使用熟悉的面向对象程序设计模型,可完全访问 JAX 的纯函数变换。

RLax 是 JAX 顶层的库,它提供了用于实现增强学习代理的有用构件。

有意思的是,Reddit 网友惊奇的发现 Haiku 这个库的名字,竟然不以 "ax" 结尾。

当然,也有网友对这两个库表示了肯定:

毫无疑问,对 JAX 起到了推动作用。

那么,我们就来看下 Haiku 和 RLex 的庐山真面目吧。

Haiku

Haiku 是 JAX 的神经网络库,它允许用户使用熟悉的面向对象编程模型,同时允许完全访问 JAX 的纯函数转换。

它提供了两个核心工具:模块抽象 hk.Module,和一个简单的函数转换 hk.transform。

hk.Module 是 Python 对象,包含对其自身参数、其他模块和对用户输入应用函数方法的引用。

hk.transform 允许完全访问 JAX 的纯函数转换。

其实,在 JAX 中有许多神经网络库,那么 Haiku 有什么特别之处呢?有 5 点。

1、Haiku 已经由 DeepMind 的研究人员进行了大规模测试

DeepMind 相对容易地在 Haiku 和 JAX 中复制了许多实验。其中包括图像和语言处理的大规模结果、生成模型和强化学习。

2、Haiku 是一个库,而不是一个框架

它的设计是为了简化一些具体的事情,包括管理模型参数和其他模型状态。可以与其他库一起编写,并与 JAX 的其他部分一起工作。

3、Haiku 并不是另起炉灶

它建立在 Sonnet 的编程模型和 API 之上,Sonnet 是 DeepMind 几乎普遍采用的神经网络库。它保留了 Sonnet 用于状态管理的基于模块的编程模型,同时保留了对 JAX 函数转换的访问。

4、过渡到 Haiku 是比较容易的

通过精心的设计,从 TensorFlow 和 Sonnet,过渡到 JAX 和 Haiku 是比较容易的。除了新的函数 ( 如 hk.transform ) ,Haiku 的目的是 Sonnet 2 的 API。

5、Haiku 简化了 JAX

它提供了一个处理随机数的简单模型。在转换后的函数中,hk.next_rng_key ( ) 返回一个唯一的 rng 键。

那么,该如何安装 Haiku呢?

Haiku 是用纯 Python 编写的,但是通过 JAX 依赖于 c++ 代码。

首先,按照下方链接中的说明,安装带有相关加速器支持的 JAX。

https://github.com/google/jax#installation

然后,只需要一句简单的 pip 命令就可以完成安装。

$ pip install git+https://github.com/deepmind/haiku

接下来,是一个神经网络和损失函数的例子。

import haiku as hk

import jax.numpy as jnp

def softmax_cross_entropy ( logits, labels ) :

one_hot = hk.one_hot ( labels, logits.shape [ -1 ] )

return -jnp.sum ( jax.nn.log_softmax ( logits ) * one_hot, axis=-1 )

def loss_fn ( images, labels ) :

model = hk.Sequential ( [

hk.Linear ( 1000 ) ,

jax.nn.relu,

hk.Linear ( 100 ) ,

jax.nn.relu,

hk.Linear ( 10 ) ,

] )

logits = model ( images )

return jnp.mean ( softmax_cross_entropy ( logits, labels ) )

loss_obj = hk.transform ( loss_fn )

RLax

它所提供的操作和函数不是完整的算法,而是强化学习特定数学操作的实现。

RLax 的安装也非常简单,一个 pip 命令就可以搞定。

pip install git+git://github.com/deepmind/rlax.git

使用 JAX 的 jax.jit 函数,所有的 RLax 代码可以不同的硬件上编译。

RLax 需要注意的是它的命名规则。

许多函数在连续的时间步长中考虑策略、操作、奖励和值,以便计算它们的输出。在这种情况下,后缀 _t 和 tm1 通常是为了说明每个输入是在哪个步骤上生成的,例如:

q_tm1:转换的源状态中的操作值。

a_tm1:在源状态下选择的操作。

r_t:在目标状态下收集的结果奖励。

q_t:目标状态下的操作值。

Haiku 和 RLax 都已在 GitHub 上开源,有兴趣的读者可从 " 传送门 " 的链接访问。

传送门

Haiku:

https://github.com/deepmind/haiku

RLax:

https://github.com/deepmind/rlax

3 期图像处理系列课程开始报名了 ~

2.27 第一期课程,来自 NVIDIA 开发者社区的何琨老师,将带领大家学习如何利用 NVIDIA 迁移式学习工具包实现实时目标检测。

欢迎大家扫下图二维码报名,记得备注 " 英伟达 " 哦 ~

直播报名 | 图像与视频处理系列课程

新年福利 | 关注 AI 发展新动态

内参新升级!拓展优质人脉,获取最新 AI 资讯 & 论文教程,欢迎加入 AI 内参社群一起学习 ~

量子位 QbitAI · 头条号签约作者

' ' 追踪 AI 技术和产品新动态

喜欢就点「在看」吧 !

以上内容由"量子位"上传发布 查看原文

觉得文章不错,微信扫描分享好友

扫码分享