A Gentle Introduction to Deep Reinforcement Learning in JAX
This article has been published on Towards Data Science.
Intro
Recent advances in Reinforcement Learning (RL), such as Waymo’s autonomous taxis or DeepMind’s superhuman chess-playing agents, complement classical RL with Deep Learning components such as Neural Networks and gradient-based optimization methods.
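To make this combination concrete, here is a minimal sketch in JAX of the core idea behind deep Q-learning: a parametrized Q-function is pushed toward the bootstrapped target r + γ·max_a' Q(s', a') by gradient descent on the squared TD error. For brevity it uses a linear function approximator in place of a neural network; the parameter names, the 4-dimensional state and the learning rate are illustrative assumptions, not the article's code.

```python
import jax
import jax.numpy as jnp

def q_values(params, state):
    # Linear Q-function (hypothetical stand-in for a neural network):
    # one column of weights per action, so the output has shape (num_actions,)
    return state @ params["w"] + params["b"]

def td_loss(params, state, action, reward, next_state, done, gamma=0.99):
    q_sa = q_values(params, state)[action]
    # Bootstrapped Q-learning target r + gamma * max_a' Q(s', a');
    # the bootstrap term is dropped when the episode has ended (done = 1)
    target = reward + gamma * (1.0 - done) * jnp.max(q_values(params, next_state))
    # Treat the target as a constant so only Q(s, a) is updated (semi-gradient)
    target = jax.lax.stop_gradient(target)
    return 0.5 * (target - q_sa) ** 2

@jax.jit
def sgd_step(params, state, action, reward, next_state, done, lr=0.1):
    # Plain gradient descent on the squared TD error
    grads = jax.grad(td_loss)(params, state, action, reward, next_state, done)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Usage: one update on a single hand-made transition with a 4-dimensional state
params = {"w": 0.1 * jax.random.normal(jax.random.PRNGKey(0), (4, 2)), "b": jnp.zeros(2)}
state = jnp.array([0.01, -0.02, 0.03, 0.0])
next_state = jnp.array([0.02, 0.01, 0.02, -0.01])
new_params = sgd_step(params, state, 1, 1.0, next_state, 0.0)
```

After the step, Q(s, a) for the chosen action has moved closer to its target, while the weights of the other action are untouched because the target is wrapped in `stop_gradient`.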
Building on the foundations and coding principles introduced in one of my previous stories, we’ll learn how to implement Deep Q-Networks (DQN) and replay buffers to solve OpenAI’s CartPole environment. And we’ll do all of that in under a second using JAX!
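A replay buffer stores past transitions (s, a, r, s', done) and serves uniformly sampled mini-batches, which breaks the correlation between consecutive experiences. Below is one possible functional implementation in JAX, assuming a fixed-capacity circular buffer held in a `NamedTuple`; `BufferState`, `init_buffer`, `add` and `sample` are hypothetical names, and the implementation developed in the article may differ.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BufferState(NamedTuple):
    states: jnp.ndarray       # (capacity, obs_dim)
    actions: jnp.ndarray      # (capacity,)
    rewards: jnp.ndarray      # (capacity,)
    next_states: jnp.ndarray  # (capacity, obs_dim)
    dones: jnp.ndarray        # (capacity,)
    position: jnp.ndarray     # index of the next slot to write
    size: jnp.ndarray         # number of filled slots


def init_buffer(capacity, obs_dim):
    # Pre-allocate fixed-shape arrays so every update keeps static shapes under jit
    return BufferState(
        states=jnp.zeros((capacity, obs_dim)),
        actions=jnp.zeros((capacity,), dtype=jnp.int32),
        rewards=jnp.zeros((capacity,)),
        next_states=jnp.zeros((capacity, obs_dim)),
        dones=jnp.zeros((capacity,)),
        position=jnp.array(0, dtype=jnp.int32),
        size=jnp.array(0, dtype=jnp.int32),
    )


@jax.jit
def add(buffer, state, action, reward, next_state, done):
    # Write one transition at `position`, overwriting the oldest entry once full
    i = buffer.position
    capacity = buffer.actions.shape[0]
    return BufferState(
        states=buffer.states.at[i].set(state),
        actions=buffer.actions.at[i].set(action),
        rewards=buffer.rewards.at[i].set(reward),
        next_states=buffer.next_states.at[i].set(next_state),
        dones=buffer.dones.at[i].set(done),
        position=(i + 1) % capacity,
        size=jnp.minimum(buffer.size + 1, capacity),
    )


@partial(jax.jit, static_argnames="batch_size")
def sample(key, buffer, batch_size):
    # Draw indices uniformly (with replacement) among the filled slots
    idx = jax.random.randint(key, (batch_size,), 0, buffer.size)
    return (buffer.states[idx], buffer.actions[idx], buffer.rewards[idx],
            buffer.next_states[idx], buffer.dones[idx])


# Usage: 5 transitions into a buffer of capacity 3, then a batch of 8
buffer = init_buffer(capacity=3, obs_dim=4)
for t in range(5):
    s = jnp.full((4,), float(t))
    buffer = add(buffer, s, t % 2, 1.0, s + 1.0, 0.0)
states, actions, rewards, next_states, dones = sample(
    jax.random.PRNGKey(0), buffer, batch_size=8)
```

Because `add` returns a new buffer instead of mutating the old one, it can be called inside jitted code or `jax.lax.scan`; once the buffer is full, the oldest transitions are overwritten first.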
For an introduction to JAX, vectorized environments, and Q-learning, please refer to that previous story.
Our deep learning framework of choice will be DeepMind’s Haiku library, which I recently introduced in an article on Transformers.
This article will cover the following sections:
- Why do we need Deep RL?
- Deep Q-Networks, theory and practice
- Replay Buffers
- Translating the CartPole environment to JAX
- The JAX way to write efficient training loops