
JAX ecosystem gains new libraries: DeepMind open-sources Haiku and RLax


2020-02-25 10:51:05

Machine Heart report

Participation: Yiming

JAX is an excellent library that provides automatic differentiation for scientific computing, along with GPU and TPU acceleration. Its ecosystem is still immature, however, and it has far fewer users than TensorFlow and PyTorch. DeepMind recently open-sourced two new JAX-based libraries, bringing new momentum to the ecosystem.

JAX is an open-source scientific computing library from Google. It can automatically differentiate native Python and NumPy code, and it runs with high performance on GPUs and TPUs. Several notable open-source projects, such as Trax, are already built on it. DeepMind has now open-sourced two more JAX-based machine learning libraries, Haiku and RLax. Each has its own focus, and both broaden the set of tools available to deep learning researchers and developers.

Haiku: https://github.com/deepmind/haiku
RLax: https://github.com/deepmind/rlax

Haiku: object-oriented development on JAX

The first library is Haiku, a deep learning library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.

Why use Haiku? Because it builds on JAX, which has considerable advantages in flexibility and performance. JAX itself, however, is functional, which differs from the object-oriented style most users are accustomed to. With Haiku, users can write object-oriented code on top of JAX.

Haiku's API and programming model are based on Sonnet, so users who know Sonnet can get started quickly. The project authors also say that Haiku is to JAX what Sonnet is to TensorFlow.

Haiku is currently released as an alpha version and is fully open source. The project authors welcome feedback from users.

How does Haiku interact with JAX?

Haiku is built around two core tools: the module abstraction hk.Module and the function transformation hk.transform. Each is described below.

An hk.Module is a Python object that holds references to its own parameters, to other modules, and to methods that apply functions to user inputs.

hk.transform turns functions that use these object-oriented modules into pure functions, which can then be passed to jax.jit, jax.grad, jax.pmap, and other JAX transformations.

What Haiku can do

Haiku covers many common needs in machine learning. Its main features and example code follow:

Defining your own modules

As in TensorFlow 2.0 and PyTorch, custom modules in Haiku are defined as subclasses of hk.Module. For example, a custom linear layer:

    import haiku as hk
    import jax
    import jax.numpy as jnp
    import numpy as np


    class MyLinear(hk.Module):

        def __init__(self, output_size, name=None):
            super().__init__(name=name)
            self.output_size = output_size

        def __call__(self, x):
            # j is the input feature size, k the output feature size.
            j, k = x.shape[-1], self.output_size
            w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
            w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
            b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
            return jnp.dot(x, w) + b

As you can see, Haiku code closely resembles TensorFlow code, and NumPy functions can also be used inside the module definition. Haiku's advantage is that it is a library rather than a closed framework, so other libraries and functions can be called while defining modules.

Once the linear layer is defined, it must be transformed into pure functions before it can be used with JAX features such as automatic differentiation:

    # Note: this listing follows the alpha API described in this article. Later
    # Haiku releases expect an RNG argument in `apply` unless the transformed
    # function is wrapped with `hk.without_apply_rng`.
    def forward_fn(x):
        model = MyLinear(10)
        return model(x)

    # Turn `forward_fn` into an object with `init` and `apply` methods.
    forward = hk.transform(forward_fn)

    x = jnp.ones([1, 1])

    # When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
    # parameter values. Haiku requires you pass a RNG key to `init`, since parameters
    # are typically initialized randomly:
    key = hk.PRNGSequence(42)
    params = forward.init(next(key), x)

    # When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
    # values from the `params` that are passed as the first argument. We do not require
    # an RNG key by default since models are deterministic. You can (of course!) change
    # this using `hk.transform(f, apply_rng=True)` if you prefer:
    y = forward.apply(params, x)

As you can see, after defining the module and the forward function, hk.transform(forward_fn) converts the object-oriented code into pure functions that JAX can process, so you do not have to manage parameters by hand. The code is also more concise than the TensorFlow equivalent.
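The listing above stops at the forward pass. Because the transformed apply function is pure and the parameters are an ordinary nested container, a gradient can be taken with jax.grad. The sketch below uses plain JAX only: the hand-written linear function stands in for forward.apply, and the dictionary layout, shapes and squared-error loss are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def linear(params, x):
    # Stand-in for a transformed `apply` function: pure, parameters passed in.
    return jnp.dot(x, params["w"]) + params["b"]


def loss_fn(params, x, target):
    # Mean squared error, chosen only to have a scalar to differentiate.
    return jnp.mean((linear(params, x) - target) ** 2)


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}  # assumed layout
x = jnp.ones((4, 3))
target = jnp.zeros((4, 2))

# jax.grad differentiates with respect to the first argument by default and
# returns gradients with the same structure as `params`.
grads = jax.grad(loss_fn)(params, x, target)
```

The result is a dictionary with the same keys and shapes as params, which is what an optimizer step needs.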

Non-trainable state

Sometimes a model needs to maintain internal state during training that is not a trainable parameter, such as the moving averages in batch normalization. This is easy to do in Haiku:

    def forward(x, is_training):
        net = hk.nets.ResNet50(1000)
        return net(x, is_training)

    forward = hk.transform_with_state(forward)

    # The `init` function now returns parameters **and** state. State contains
    # anything that was created using `hk.set_state`. The structure is the same as
    # params (e.g. it is a per-module mapping of named values).
    params, state = forward.init(rng, x, is_training=True)

    # The apply function now takes both params **and** state. Additionally it will
    # return updated values for state. In the resnet example this will be the
    # updated values for moving averages used in the batch norm layers.
    logits, state = forward.apply(params, state, rng, x, is_training=True)

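As shown above, only a couple of changes are needed: transform the function with hk.transform_with_state, then pass state into init and apply and read the updated state back.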
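The pattern of passing state in and getting updated state back can be sketched without Haiku. The example below is a hand-written stand-in, not Haiku code: it keeps a moving average of the input mean, and the names init_state, apply and decay are hypothetical.

```python
import jax.numpy as jnp


def init_state():
    # Non-trainable state: a running mean, starting at zero.
    return {"mean": jnp.zeros(())}


def apply(state, x, is_training, decay=0.9):
    """Center `x` with a running mean and return (output, new_state)."""
    if is_training:
        # Build a new state dictionary instead of mutating the old one.
        new_mean = decay * state["mean"] + (1.0 - decay) * jnp.mean(x)
        state = {"mean": new_mean}
    return x - state["mean"], state


state = init_state()
x = jnp.full((4,), 10.0)
out, state = apply(state, x, is_training=True)            # mean moves toward 10
eval_out, eval_state = apply(state, x, is_training=False)  # state is unchanged
```

Since the function never mutates anything in place, the caller decides which state value to carry into the next step. This keeps the function compatible with JAX transformations.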

Distributed training with jax.pmap

Because all Haiku code is converted into pure JAX functions, it is fully compatible with jax.pmap. This means jax.pmap can be used for distributed computation.

The following code implements data-parallel training. First, define the model and its loss, then initialize the parameters:

    # `softmax_cross_entropy` and `input_dataset` are assumed to be defined elsewhere.
    def loss_fn(inputs, labels):
        logits = hk.nets.MLP([8, 4, 2])(inputs)
        return jnp.mean(softmax_cross_entropy(logits, labels))

    loss_obj = hk.transform(loss_fn)

    # Initialize the model on a single device.
    rng = jax.random.PRNGKey(428)
    sample_image, sample_label = next(input_dataset)
    params = loss_obj.init(rng, sample_image, sample_label)
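The listing calls softmax_cross_entropy without defining it. One possible definition is sketched below, under the assumption that labels are integer class indices. The source does not say which label format it uses, so treat this as an illustration only.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy(logits, labels):
    """Per-example cross-entropy. Assumes `labels` holds integer class ids."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    # Pick out the log-probability of the correct class for each example.
    return -jnp.sum(one_hot * log_probs, axis=-1)


example_logits = jnp.zeros((2, 2))   # uniform predictions over two classes
example_labels = jnp.array([0, 1])
example_loss = softmax_cross_entropy(example_logits, example_labels)
```

With uniform logits over two classes, each example's loss equals log 2.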

Then replicate the parameters onto all devices:

    # Replicate params onto all devices.
    num_devices = jax.local_device_count()
    params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

Define how data batches are built and how the parameters are updated:

    def make_superbatch():
        """Constructs a superbatch, i.e. one batch of data per device."""
        # Get N batches, then split into list-of-images and list-of-labels.
        superbatch = [next(input_dataset) for _ in range(num_devices)]
        superbatch_images, superbatch_labels = zip(*superbatch)
        # Stack the superbatches to be one array with a leading dimension, rather than
        # a python list. This is what `jax.pmap` expects as input.
        superbatch_images = np.stack(superbatch_images)
        superbatch_labels = np.stack(superbatch_labels)
        return superbatch_images, superbatch_labels

    def update(params, inputs, labels, axis_name='i'):
        """Updates params based on performance on inputs and labels."""
        grads = jax.grad(loss_obj.apply)(params, inputs, labels)
        # Take the mean of the gradients across all data-parallel replicas.
        grads = jax.lax.pmean(grads, axis_name)
        # Update parameters using SGD or Adam or ...
        new_params = my_update_rule(params, grads)
        return new_params
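The update function relies on my_update_rule, which the source leaves to the reader. A minimal sketch using plain stochastic gradient descent follows. The constant LEARNING_RATE and the choice of SGD are assumptions, and an optimizer library could be used instead.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value, chosen for illustration


def my_update_rule(params, grads):
    """Plain SGD step applied to every leaf of the parameter tree."""
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)


example_params = {"w": jnp.ones((2,)), "b": jnp.zeros(())}
example_grads = {"w": 2.0 * jnp.ones((2,)), "b": jnp.ones(())}
example_new_params = my_update_rule(example_params, example_grads)
```

jax.tree_util.tree_map walks params and grads together, so the same rule works for any nesting of dictionaries, lists or tuples, as long as the two trees share a structure.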

Finally, run the distributed training loop:

    # Run several training updates.
    for _ in range(10):
        superbatch_images, superbatch_labels = make_superbatch()
        params = jax.pmap(update, axis_name='i')(
            params, superbatch_images, superbatch_labels)
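The listings above depend on a dataset and a model that are not shown. The same data-parallel pattern can be tried end to end with a small self-contained sketch in plain JAX. The one-parameter model, the synthetic data and the step size are assumptions made for illustration. On a machine with a single device it still runs, with one replica.

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()


def loss_fn(w, x, y):
    # Fit a single scalar weight so that x * w matches y.
    return jnp.mean((x * w - y) ** 2)


def update(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Average gradients across all replicas taking part in the pmap.
    grads = jax.lax.pmean(grads, axis_name="i")
    return w - 0.1 * grads


# The leading axis indexes devices: one weight copy and one batch per device.
w = jnp.zeros((num_devices,))
x = jnp.ones((num_devices, 8))
y = 2.0 * jnp.ones((num_devices, 8))

p_update = jax.pmap(update, axis_name="i")
for _ in range(50):
    w = p_update(w, x, y)
```

Every replica starts from the same weight and applies the same averaged gradient, so the copies stay in sync and move toward the target value 2.0.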

RLax: a reinforcement learning library on JAX

Alongside Haiku, DeepMind also open-sourced RLax, a JAX-based reinforcement learning library.

Compared with Haiku, RLax is aimed specifically at reinforcement learning. The project authors note that the operations and functions it provides are not complete algorithms. Instead, they are the reinforcement-learning-specific mathematical operations needed as building blocks for fully functional agents.

Functional JAX is therefore a good fit, and a dedicated reinforcement learning library can be built on top of it. RLax currently has little documentation, but the project provides example code that uses RLax to build and train a Q-learning agent.

The code is as follows. First, build a simple Q-network using Haiku:

    def build_network(num_actions: int) -> hk.Transformed:

        def q(obs):
            flatten = lambda x: jnp.reshape(x, (-1,))
            network = hk.Sequential(
                [flatten, nets.MLP([FLAGS.hidden_units, num_actions])])
            return network(obs)

        return hk.transform(q)

Set up the training loop:

    def main_loop(unused_arg):
        env = catch.Catch(seed=FLAGS.seed)
        rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

        # Build and initialize Q-network.
        num_actions = env.action_spec().num_values
        network = build_network(num_actions)
        sample_input = env.observation_spec().generate_value()
        net_params = network.init(next(rng), sample_input)

        # Build and initialize optimizer.
        optimizer = optix.adam(FLAGS.learning_rate)
        opt_state = optimizer.init(net_params)

Next, still inside main_loop, define the policies and the update step, each compiled with jax.jit:

        # (continued inside main_loop)
        @jax.jit
        def policy(net_params, key, obs):
            """Sample action from epsilon-greedy policy."""
            q = network.apply(net_params, obs)
            a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
            return q, a

        @jax.jit
        def eval_policy(net_params, key, obs):
            """Sample action from greedy policy."""
            q = network.apply(net_params, obs)
            return rlax.greedy().sample(key, q)

        @jax.jit
        def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
            """Update network weights wrt Q-learning loss."""

            def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
                q_tm1 = network.apply(net_params, obs_tm1)
                td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
                return rlax.l2_loss(td_error)

            dloss_dtheta = jax.grad(q_learning_loss)(
                net_params, obs_tm1, a_tm1, r_t, discount_t, q_t)
            updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
            net_params = optix.apply_updates(net_params, updates)
            return net_params, opt_state

        print(f"Training agent for {FLAGS.train_episodes} episodes...")

As you can see, the policy and update functions are compiled with jax.jit, which can improve performance. More interestingly, the model itself is built with Haiku, which shows that libraries in the JAX ecosystem are compatible with one another.
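For readers unfamiliar with the quantities in the update step, the textbook Q-learning temporal-difference error and a squared loss can be written in a few lines of plain JAX. The functions below are hand-written for illustration and are not RLax's source code. The argument names mirror those in the listing, and the 0.5 factor in the loss is an assumption about the convention.

```python
import jax
import jax.numpy as jnp


def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    """TD error: r_t + discount_t * max_a q_t[a] - q_tm1[a_tm1]."""
    target = r_t + discount_t * jnp.max(q_t)
    # Treat the bootstrap target as a constant when differentiating.
    target = jax.lax.stop_gradient(target)
    return target - q_tm1[a_tm1]


def l2_loss(td_error):
    # Half squared error (assumed convention).
    return 0.5 * td_error ** 2


q_tm1 = jnp.array([1.0, 2.0])   # action values at the previous step
q_t = jnp.array([0.0, 3.0])     # action values at the current step
td = q_learning_td_error(q_tm1, a_tm1=1, r_t=1.0, discount_t=0.9, q_t=q_t)
loss = l2_loss(td)              # target 3.7, td 1.7, loss 1.445
```

Each function maps arrays to arrays without hidden state, so it composes directly with jax.grad and jax.jit in the way the update step above does.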

As the two libraries recently open-sourced by DeepMind show, deep learning frameworks continue to develop steadily, and high-performance scientific computing is becoming increasingly important to them. Excellent open-source projects like JAX undoubtedly need more ecosystem support. The release of Haiku and RLax should help consolidate JAX's position and make fuller use of its strengths.

https://www.toutiao.com/i6797211131191493124/
