Awesome JAX-based Models
This appendix extends the preceding one with a list of notable JAX-based models.
In addition to libraries, a number of ready-made models and research projects have been developed in JAX. This list is a useful starting point for research or data science work that can build on an existing JAX implementation.
Models and projects
JAX
- Fourier Feature Networks: Official implementation of the paper "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains."
- kalman-jax: Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.
- GPJax: Gaussian processes in JAX.
- jaxns: Nested sampling in JAX.
- Amortized Bayesian Optimization: Code for the paper "Amortized Bayesian Optimization over Discrete Spaces."
- Accurate Quantized Training: Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.
- BNN-HMC: Implementation for the paper "What Are Bayesian Neural Network Posteriors Really Like?"
- JAX-DFT: One-dimensional density functional theory (DFT) in JAX, with implementation of Kohn-Sham