JASMINE
JAX Accelerated Statistical Models and Integrated Neural Engine
JASMINE accelerates statistical models with a scikit-learn-like API backed by GPU- and TPU-optimized JAX primitives.
Overview
JASMINE is a high-performance machine learning library built on top of JAX. It pairs the ergonomics of familiar scikit-learn APIs with JAX's automatic differentiation, GPU/TPU accelerator support, and composable pipelines that are JIT-compiled by XLA into fast accelerator kernels.
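Here is a minimal sketch in plain JAX that shows what a "composable pipeline" can mean. It does not use JASMINE's API, which is not shown here. Preprocessing, model, and loss are ordinary functions chained together, and jax.value_and_grad and jax.jit then transform the whole chain as one unit. The names standardize, linear, and pipeline_loss are hypothetical.

```python
import jax
import jax.numpy as jnp


def standardize(X, mean, scale):
    # Preprocessing stage: center and scale each feature column.
    return (X - mean) / scale


def linear(params, X):
    # Model stage: affine prediction X @ w + b.
    w, b = params
    return X @ w + b


def pipeline_loss(params, X, y, mean, scale):
    # The "pipeline" is ordinary function composition: scale -> predict -> MSE.
    pred = linear(params, standardize(X, mean, scale))
    return jnp.mean((pred - y) ** 2)


# value_and_grad differentiates through every stage; jit traces the whole
# chain once and hands it to XLA as a single compiled computation.
loss_and_grad = jax.jit(jax.value_and_grad(pipeline_loss))

X = jnp.array([[0.0], [2.0]])
y = jnp.array([1.0, 3.0])
params = (jnp.zeros(1), jnp.array(0.0))
loss, (grad_w, grad_b) = loss_and_grad(params, X, y, jnp.array([1.0]), jnp.array([1.0]))
```

The chain is plain function composition. You can swap the model or add a stage without changing how the chain is compiled or differentiated.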
Features
JIT-compiled models
Linear and logistic regression models are JIT-compiled through XLA into GPU/TPU kernels, which speeds up both training and inference.
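The sketch below shows what JIT-compiled linear and logistic regression can look like in plain JAX. It is an illustration under assumptions, not JASMINE's implementation. The names fit_linear_regression, predict_linear, predict_logistic_proba, and add_intercept are hypothetical. The fit solves ordinary least squares through the normal equations.

```python
import jax
import jax.numpy as jnp


def add_intercept(X):
    # Prepend a column of ones so coef[0] acts as the intercept.
    return jnp.hstack([jnp.ones((X.shape[0], 1), dtype=X.dtype), X])


@jax.jit
def fit_linear_regression(X, y):
    # Ordinary least squares via the normal equations (X^T X) coef = X^T y.
    Xb = add_intercept(X)
    return jnp.linalg.solve(Xb.T @ Xb, Xb.T @ y)


@jax.jit
def predict_linear(coef, X):
    return add_intercept(X) @ coef


@jax.jit
def predict_logistic_proba(coef, X):
    # Logistic regression shares the linear score; the sigmoid maps it to P(y=1).
    return jax.nn.sigmoid(add_intercept(X) @ coef)


X = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = 1.0 + 2.0 * X[:, 0]  # exact line: intercept 1, slope 2
coef = fit_linear_regression(X, y)
```

The first call to each jitted function traces and compiles it for the given input shapes and dtypes. Later calls with matching shapes reuse the compiled code. The normal equations are the shortest route, but they can be ill-conditioned. For nearly collinear features, a QR- or SVD-based solver such as jnp.linalg.lstsq is more robust.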
Multiple optimizers
Tune convergence with plain SGD, SGD with momentum, or Adam. Of these three, only Adam adapts a per-parameter learning rate.
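The three optimizers differ in their update rules. The sketch below writes each rule by hand in JAX and runs it on a toy quadratic. It assumes textbook formulations: heavy-ball momentum, and Adam with bias correction and its usual beta and epsilon defaults. It is not taken from JASMINE's source, and the names sgd_step, momentum_step, adam_step, and run are hypothetical.

```python
import jax
import jax.numpy as jnp


def sgd_step(w, g, state, lr=0.1):
    # Plain SGD: a fixed-size step against the gradient; state is unused.
    return w - lr * g, state


def momentum_step(w, g, velocity, lr=0.1, beta=0.9):
    # Heavy-ball momentum: a decaying sum of past gradients smooths the path,
    # but the learning rate itself stays fixed.
    velocity = beta * velocity + g
    return w - lr * velocity, velocity


def adam_step(w, g, state, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
    # Adam: bias-corrected first/second moment estimates give each
    # parameter its own effective step size lr / sqrt(v_hat).
    m, v, t = state
    t = t + 1
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (jnp.sqrt(v_hat) + eps), (m, v, t)


def loss(w):
    # Toy convex objective with its minimum at w = [3, 3].
    return jnp.sum((w - 3.0) ** 2)


grad_fn = jax.jit(jax.grad(loss))


def run(step, state, n_steps=500):
    w = jnp.zeros(2)
    for _ in range(n_steps):
        w, state = step(w, grad_fn(w), state)
    return w


w_sgd = run(sgd_step, None)
w_momentum = run(momentum_step, jnp.zeros(2))
w_adam = run(adam_step, (jnp.zeros(2), jnp.zeros(2), 0))
```

Only adam_step divides by a per-parameter statistic, sqrt(v_hat), and that division is what "adaptive learning rate" refers to. Momentum changes the direction and size of steps through accumulated gradients, but lr stays fixed.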
Advanced regularization
Apply L1 (Lasso), L2 (Ridge), or Elastic Net (a weighted mix of L1 and L2) penalties to control overfitting. The L1 term favors sparse coefficients, which keeps models interpretable.
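One common way to parameterize these penalties, borrowed from scikit-learn's ElasticNet objective, uses a strength alpha and a mixing weight l1_ratio. The sketch below assumes that parameterization; JASMINE's own parameter names may differ. It adds the penalty to a mean-squared-error loss so that jax.grad can differentiate the total. The names elastic_net_penalty and regularized_mse are hypothetical.

```python
import jax
import jax.numpy as jnp


def elastic_net_penalty(w, alpha=1.0, l1_ratio=0.5):
    # alpha * l1_ratio * ||w||_1 + 0.5 * alpha * (1 - l1_ratio) * ||w||_2^2
    # l1_ratio=1.0 is a pure L1 (Lasso) penalty, l1_ratio=0.0 a pure L2 (Ridge) one.
    l1 = jnp.sum(jnp.abs(w))
    l2_sq = jnp.sum(w ** 2)
    return alpha * l1_ratio * l1 + 0.5 * alpha * (1.0 - l1_ratio) * l2_sq


def regularized_mse(w, X, y, alpha=1.0, l1_ratio=0.5):
    # (1 / 2n) * ||X w - y||^2 plus the penalty; intercept omitted for brevity.
    residual = X @ w - y
    return 0.5 * jnp.mean(residual ** 2) + elastic_net_penalty(w, alpha, l1_ratio)


# Where some w_j == 0, |w_j| has a kink and JAX returns one particular subgradient.
grad_fn = jax.jit(jax.grad(regularized_mse))

w = jnp.array([1.0, -2.0])
g = grad_fn(w, jnp.eye(2), jnp.zeros(2))
```

Plain gradient steps on the L1 term rarely land exactly on zero. Solvers that need exact sparsity typically use coordinate descent or a proximal (soft-thresholding) update instead.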
Data preprocessing
Accelerated scalers and transformation utilities keep preprocessing on the accelerator, which avoids round trips between host and device memory.
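A standard scaler reduces to a per-column mean and standard deviation, and both can be computed and applied on the accelerator. This sketch uses plain JAX with hypothetical function names (fit_standard_scaler, transform, inverse_transform). It assumes scikit-learn-style conventions: population standard deviation, and a scale of 1 for constant columns.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fit_standard_scaler(X):
    # Per-column mean and population standard deviation, computed on device.
    mean = jnp.mean(X, axis=0)
    scale = jnp.std(X, axis=0)
    # Constant columns would divide by zero; use a scale of 1 for them instead.
    scale = jnp.where(scale == 0.0, 1.0, scale)
    return mean, scale


@jax.jit
def transform(X, mean, scale):
    return (X - mean) / scale


@jax.jit
def inverse_transform(Z, mean, scale):
    return Z * scale + mean


X = jnp.array([[1.0, 10.0], [3.0, 10.0]])
mean, scale = fit_standard_scaler(X)
Z = transform(X, mean, scale)
```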
Sklearn-compatible API
Estimators follow familiar scikit-learn patterns, so teams can adopt JASMINE without relearning their tooling.
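In practice, "scikit-learn-compatible" usually means a few conventions:
- hyperparameters are set in __init__;
- fit returns self;
- learned attributes carry a trailing underscore;
- get_params and set_params are available for tooling.

The JaxRidge class below is a hypothetical sketch of those conventions wrapped around a jitted closed-form ridge solve. It is not JASMINE's estimator, and its names are illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def _ridge_solve(X, y, alpha):
    # Center the data so the intercept is not penalized, then solve
    # (Xc^T Xc + alpha * I) coef = Xc^T yc.
    x_mean = jnp.mean(X, axis=0)
    y_mean = jnp.mean(y)
    Xc, yc = X - x_mean, y - y_mean
    A = Xc.T @ Xc + alpha * jnp.eye(X.shape[1], dtype=X.dtype)
    coef = jnp.linalg.solve(A, Xc.T @ yc)
    return coef, y_mean - x_mean @ coef


class JaxRidge:
    """Hypothetical estimator following scikit-learn conventions (not JASMINE's class)."""

    def __init__(self, alpha=1.0):
        # Hyperparameters are stored unchanged in __init__, as scikit-learn expects.
        self.alpha = alpha

    def get_params(self, deep=True):
        return {"alpha": self.alpha}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        X = jnp.asarray(X, dtype=jnp.float32)
        y = jnp.asarray(y, dtype=jnp.float32)
        # Learned state uses the trailing-underscore naming convention.
        self.coef_, self.intercept_ = _ridge_solve(X, y, self.alpha)
        return self  # returning self enables model.fit(X, y).predict(X)

    def predict(self, X):
        return jnp.asarray(X, dtype=jnp.float32) @ self.coef_ + self.intercept_


X = [[0.0], [1.0], [2.0], [3.0]]
y = [1.0, 3.0, 5.0, 7.0]
ols = JaxRidge(alpha=0.0).fit(X, y)
shrunk = JaxRidge().set_params(alpha=5.0).fit(X, y)
```

Full interoperability with scikit-learn utilities such as GridSearchCV is usually achieved by subclassing sklearn.base.BaseEstimator. The sketch avoids that dependency to stay self-contained.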
Automatic differentiation
Harness the full JAX autodiff stack for gradients, Jacobians, and custom training loops.
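Each of the three uses named above maps onto a JAX transformation:
- jax.grad computes gradients of a scalar loss;
- jax.jacfwd (or jax.jacrev) computes Jacobians;
- jax.value_and_grad inside a jitted step supports a custom training loop.

The sketch uses a tiny logistic regression written in plain JAX. The names predict_proba, logistic_loss, and train_step are hypothetical, not JASMINE functions.

```python
import jax
import jax.numpy as jnp


def predict_proba(w, X):
    return jax.nn.sigmoid(X @ w)


def logistic_loss(w, X, y):
    # Binary cross-entropy written stably: softplus(z) - y * z, with z = X @ w.
    z = X @ w
    return jnp.mean(jax.nn.softplus(z) - y * z)


X = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 0.0])
w0 = jnp.zeros(2)

# Gradient of the scalar loss with respect to the weights.
g = jax.grad(logistic_loss)(w0, X, y)

# Jacobian of the vector of predicted probabilities with respect to the weights.
J = jax.jacfwd(predict_proba)(w0, X)


@jax.jit
def train_step(w, X, y, lr=0.5):
    # One step of a hand-written gradient-descent loop.
    loss, grad = jax.value_and_grad(logistic_loss)(w, X, y)
    return w - lr * grad, loss


w = w0
losses = []
for _ in range(50):
    w, loss = train_step(w, X, y)
    losses.append(float(loss))
```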
Note: The roadmap includes additional models, loss functions, preprocessing utilities, and evaluation metrics as the project matures.
Want to learn more about JASMINE or contribute? Check out the documentation or the GitHub repository.