Introduction to JAX and Deep Learning
Discover the power of JAX for deep learning. Gain insight into its ecosystem and learn about linear algebra, pseudo-random number generation, and optimization algorithms to write cleaner, more structured code.
Rating 4.8 · 53 lessons · 2h 30min
LEARNING OBJECTIVES
- Learn the basics of JAX
- Apply automatic differentiation (autograd) with JAX
- Use automatic vectorization for batching
- Implement neural networks with Haiku and Flax
- Use Optax and get an overview of common optimization algorithms in deep learning
- Test JAX programs with Chex
- Learn the basics of applied linear algebra
- Learn the theory of random variables and probability distributions
- Learn pseudo-random number generation
- Cover the basics of optimal transport
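As a small taste of two of the objectives above, here is a sketch of JAX's explicit pseudo-random number generation combined with `jax.vmap` batching. It is illustrative only, not code from the course; the helper `dot_with_ones` and the array shapes are hypothetical.

```python
import jax
import jax.numpy as jnp

# JAX PRNG state is explicit: a key is passed around and split, never mutated.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # derive a fresh key for this draw

# A hypothetical batch of 4 input vectors, each of length 3.
xs = jax.random.normal(subkey, (4, 3))

def dot_with_ones(x):
    # Written for a single example of shape (3,); equals the sum of x.
    return jnp.dot(x, jnp.ones(3))

# vmap maps the per-example function over the leading (batch) axis.
batched = jax.vmap(dot_with_ones)
ys = batched(xs)
print(ys.shape)
```

Because keys are values rather than hidden global state, drawing again with the same `subkey` reproduces exactly the same numbers, which is what makes random code in JAX safe to transform and compile.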
Learning Roadmap
1. Introduction
Get familiar with JAX, a powerful library for deep learning and numerical computing.
2. JAX Programming Model
Walk through JAX's programming model, including pure functions, just-in-time (JIT) compilation, jaxprs, and automatic differentiation.
3. Linear Algebra (15 lessons)
Explore the fundamental concepts of vectors, matrices, multivariate calculus, and convolutions in deep learning.
4. Random Variables and Distributions (7 lessons)
Grasp the fundamentals of random variables, probability distributions, pseudo-random number generators (PRNGs), and divergence measures in JAX.
5. JAX Ecosystem (14 lessons)
Take a closer look at the tools and libraries within the JAX ecosystem for deep learning.
6. Appendix (6 lessons)
Review installation steps, notable JAX libraries and models, vector calculus, common errors, and key terms.
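The programming-model ideas in module 2 can be illustrated with a few lines of core JAX: a pure loss function, its jaxpr from `jax.make_jaxpr`, and its gradient from `jax.grad`, compiled with `jax.jit`. This is a minimal sketch assuming a toy least-squares loss; `loss`, `w`, `x`, and `y` are hypothetical names, not taken from the course.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Pure function: the result depends only on the inputs, with no side effects.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Hypothetical toy data.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

# jaxpr: the intermediate representation JAX traces the function into.
print(jax.make_jaxpr(loss)(w, x, y))

# grad differentiates with respect to the first argument (w) by default;
# jit compiles the resulting gradient function with XLA.
grad_fn = jax.jit(jax.grad(loss))
g = grad_fn(w, x, y)
print(g)
```

With `w` at zero the residuals are `-y`, so the mean-squared-error gradient is `(2/n) * x.T @ (-y)`, which works out to `[-7, -10]` for this data.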
Certificate of Completion
Showcase your accomplishment by sharing your certificate of completion.
Developed by MAANG Engineers
ABOUT THIS COURSE
JAX is a Python library designed for high-performance ML research. It is a numerical computing library with a NumPy-like API, extended with key capabilities such as automatic differentiation, just-in-time (JIT) compilation, and automatic vectorization.
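One concrete way `jax.numpy` resembles NumPy yet differs from it: it mirrors much of the NumPy API, but JAX arrays are immutable, so updates go through the `.at[...]` syntax and return a new array. A minimal illustrative sketch (the variable names are hypothetical):

```python
import numpy as np
import jax.numpy as jnp

# jax.numpy mirrors much of the NumPy API.
a_np = np.arange(5.0)
a_jx = jnp.arange(5.0)

# NumPy arrays are mutable: this updates a_np in place.
a_np[0] = 10.0

# JAX arrays are immutable: .at[...].set(...) returns an updated copy.
b_jx = a_jx.at[0].set(10.0)

# Direct item assignment on a JAX array raises a TypeError.
try:
    a_jx[0] = 10.0
    mutated = True
except TypeError:
    mutated = False
```

Immutability is a deliberate design choice: it keeps functions pure, which is what allows transformations such as `jit`, `grad`, and `vmap` to be applied safely.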
In this course, you will learn about JAX and its ecosystem of libraries (Haiku, Jraph, Chex, Flax, Optax). Aimed at a wide range of audiences, the course covers several topics, including linear algebra, random variable theory, pseudo-random number generation, and optimization algorithms.
By the end of this course, you will have a new set of skills that will make deep learning programming more intuitive, structured, and clean.
ABOUT THE AUTHOR
Khayyam Hashmi
Computer scientist specializing in generative AI and machine learning. VP of Technical Content at educative.io.
Trusted by 2.9 million developers.