JASMINE

JAX Accelerated Statistical Models and Integrated Neural Engine

JASMINE accelerates statistical models with a scikit-learn-like API backed by GPU/TPU-optimized JAX primitives.

In Development · 2025 · JAX · Machine Learning · GPU/TPU · Python

Overview

JASMINE is a high-performance machine learning library built on top of JAX. It marries the ergonomics of familiar scikit-learn APIs with the benefits of auto-differentiation, accelerator support, and composable pipelines that compile to lightning-fast kernels.

Features

JIT-compiled models

Linear and logistic regression models compile down to GPU/TPU kernels, delivering speedups for both training and inference.
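
As a sketch of what JIT compilation buys here, the following minimal JAX example (plain JAX, not JASMINE's actual internals) compiles a full gradient-descent step for linear regression into a single fused kernel:

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, X, y, lr=0.1):
    # One gradient-descent step on mean-squared error, compiled on first call.
    def mse(p):
        preds = X @ p["w"] + p["b"]
        return jnp.mean((preds - y) ** 2)
    grads = jax.grad(mse)(params)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Toy data: y = 3x + 1 (noise-free for brevity)
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (256, 1))
y = 3.0 * X[:, 0] + 1.0

params = {"w": jnp.zeros(1), "b": jnp.zeros(())}
for _ in range(200):
    params = train_step(params, X, y)
```

After the first call traces and compiles the step, every subsequent iteration reuses the cached kernel, which is where the training speedup comes from.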

Multiple optimizers

Tune convergence with SGD, Momentum, and Adam optimizers, including Adam's adaptive per-parameter learning rates.
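
For reference, these are the textbook update rules the three optimizers implement; the function names and signatures below are illustrative, not JASMINE's API:

```python
import jax.numpy as jnp

def sgd_step(param, grad, lr=0.01):
    # Plain gradient descent: step against the gradient.
    return param - lr * grad

def momentum_step(param, velocity, grad, lr=0.01, beta=0.9):
    # Accumulate an exponentially decaying velocity, then step along it.
    velocity = beta * velocity + grad
    return param - lr * velocity, velocity

def adam_step(param, m, v, grad, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    # Track first- and second-moment estimates of the gradient...
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad**2
    # ...then correct their initialization bias before the adaptive step.
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    return param - lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v
```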

Advanced regularization

Combine L1, L2, and Elastic Net penalties to control overfitting while keeping models interpretable.
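
As a sketch of how the penalties combine, Elastic Net mixes the L1 and L2 norms with a ratio; the parameter names below follow scikit-learn's convention (alpha, l1_ratio) and are an assumption about JASMINE's interface:

```python
import jax.numpy as jnp

def elastic_net_penalty(w, alpha=1.0, l1_ratio=0.5):
    # l1_ratio=1.0 is pure Lasso (L1); l1_ratio=0.0 is pure Ridge (L2).
    l1 = jnp.sum(jnp.abs(w))
    l2 = 0.5 * jnp.sum(w**2)
    return alpha * (l1_ratio * l1 + (1.0 - l1_ratio) * l2)

# Added to a data loss, the penalty shrinks weights (L2) and
# encourages sparsity (L1), keeping the model interpretable:
# total_loss = mse(w) + elastic_net_penalty(w, alpha=0.1, l1_ratio=0.7)
```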

Data preprocessing

JIT-compatible data scalers and transformation utilities keep preprocessing on the accelerator path.
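
A minimal example of the technique, assuming a standard-scaler-style transform; this is plain JAX, not JASMINE's actual utility:

```python
import jax
import jax.numpy as jnp

@jax.jit
def standard_scale(X, eps=1e-8):
    # Per-feature zero mean and unit variance, computed on device
    # so the scaled array never round-trips through host memory.
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    return (X - mean) / (std + eps)  # eps guards against constant features
```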

Sklearn-compatible API

Drop-in friendly fit/predict patterns mean teams can adopt JASMINE without relearning their tooling.
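
A hypothetical usage sketch of what a scikit-learn-compatible workflow could look like; the import path and class name are illustrative assumptions, so check the documentation for the real API:

```python
import jax.numpy as jnp

# Hypothetical import path and class name -- illustrative only,
# not confirmed against JASMINE's actual package layout.
from jasmine import LinearRegression

X = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])

model = LinearRegression()  # same fit/predict contract as scikit-learn
model.fit(X, y)
preds = model.predict(X)
```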

Automatic differentiation

Harness the full JAX autodiff stack for gradients, Jacobians, and custom training loops.
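
For example, plain JAX already exposes gradients and Jacobians as one-line transformations, and that is the stack JASMINE builds on:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.tanh(x @ w)  # toy model mapping 3 weights to 2 outputs

w = jnp.ones(3)
x = jnp.arange(6.0).reshape(2, 3)

# Gradient of a scalar loss with respect to the parameters.
loss = lambda w: jnp.sum(predict(w, x) ** 2)
grad_w = jax.grad(loss)(w)                      # shape (3,)

# Full Jacobian of the vector-valued model output.
jac = jax.jacobian(lambda w: predict(w, x))(w)  # shape (2, 3)
```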

Note: The roadmap includes additional models, loss functions, preprocessing utilities, and evaluation metrics as the project matures.

Want to learn more about JASMINE or contribute? Check out the documentation or the GitHub repository.