Keras: Deep Learning for humans
A superpower for ML developers
Keras is a deep learning API designed for human beings, not machines. Keras focuses on debugging speed, code elegance and conciseness, maintainability, and deployability. When you choose Keras, your codebase is smaller, more readable, and easier to iterate on.
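import keras
from keras import layers

# Small residual CNN for 32x32 RGB inputs, built with the Functional API
inputs = keras.Input(shape=(32, 32, 3))
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
residual = x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(64, 3, padding="same")(x)
x = layers.Activation("relu")(x)
x = layers.Conv2D(64, 3, padding="same")(x)
x = layers.Activation("relu")(x)
x = x + residual  # skip connection; shapes match because of padding="same"
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs, name="mini_resnet")
keras.utils.plot_model(model, "mini_resnet.png")  # requires pydot and Graphviz
# Assumption: `dataset` yields (images, integer_labels) batches
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(dataset, epochs=10)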
import keras_hub

causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma2_instruct_2b_en",
    dtype="float16",
)
prompt = """<start_of_turn>user
Write python code to print the first 100 primes.
<end_of_turn>
<start_of_turn>model
"""
text_output = causal_lm.generate(prompt, max_length=512)
text_to_image = keras_hub.models.TextToImage.from_preset(
    "stable_diffusion_3_medium",
    dtype="float16",
)
prompt = "Astronaut in a jungle, detailed"
image_output = text_to_image.generate(prompt)
Welcome to multi-framework machine learning
With its multi-backend approach, Keras gives you the freedom to work with JAX, TensorFlow, and PyTorch. Build models that can move seamlessly across these frameworks and leverage the strengths of each ecosystem.
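Choosing a backend is typically a configuration step rather than a code change. As a minimal sketch: Keras 3 reads the `KERAS_BACKEND` environment variable the first time it is imported, so the variable has to be set beforehand. The `select_backend` helper below is a hypothetical convenience wrapper written for illustration, not part of Keras.

```python
import os

# The three backends named above, as spelled in KERAS_BACKEND.
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name: str) -> str:
    """Hypothetical helper: choose a backend before Keras is imported."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")
# A later `import keras` in this process would now use JAX,
# provided the jax package is installed.
```

Setting the variable after Keras has already been imported has no effect on the running process, which is why the helper is meant to run first.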
inputs = keras.Input(shape=(28, 28, 1))
x = inputs
x = layers.Conv2D(16, 3, activation="relu")(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.Conv2D(16, 3, activation="relu")(x)
x = layers.GlobalMaxPooling2D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10)(x)
model = keras.Model(inputs, outputs)
model.summary()
The Functional API
Start from the beginning and learn how to build models using the functional building pattern.
model.compile(
    optimizer="rmsprop",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=2,
    validation_data=(x_val, y_val),
)
Training & evaluation with the built-in methods
Train and evaluate your model using model.fit(...).
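To make the compile/fit workflow concrete, here is a small end-to-end sketch. The data is random and stands in for a real dataset, and the tiny two-layer model, the array shapes, and the variable names are assumptions chosen only so the example runs quickly.

```python
import numpy as np
import keras
from keras import layers

# Synthetic stand-in data: 20 features, 10 classes (assumed shapes).
rng = np.random.default_rng(0)
x_train = rng.random((256, 20)).astype("float32")
y_train = keras.utils.to_categorical(rng.integers(0, 10, size=256), num_classes=10)
x_val = rng.random((64, 20)).astype("float32")
y_val = keras.utils.to_categorical(rng.integers(0, 10, size=64), num_classes=10)

# A deliberately tiny classifier; softmax pairs with categorical_crossentropy.
model = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        layers.Dense(32, activation="relu"),
        layers.Dense(10, activation="softmax"),
    ]
)
model.compile(
    optimizer="rmsprop",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# fit() returns a History object with per-epoch metric values.
history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=2,
    validation_data=(x_val, y_val),
    verbose=0,
)

# evaluate() returns the loss followed by each compiled metric.
val_loss, val_acc = model.evaluate(x_val, y_val, verbose=0)
print(sorted(history.history.keys()))
```

With random labels the accuracy is not meaningful; the point is the shape of the workflow, where `history.history` holds one value per epoch for each tracked quantity.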
class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense_1 = layers.Dense(32)
        self.dense_2 = layers.Dense(32)
        self.dense_3 = layers.Dense(1)

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = keras.activations.relu(x)
        x = self.dense_2(x)
        x = keras.activations.relu(x)
        return self.dense_3(x)
Making new layers and models via subclassing
Learn how to customize your model via subclassing Keras layers.
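A subclassed layer is used like any built-in one: instantiate it, then call it on a batch. The sketch below repeats the `MLPBlock` definition so it runs on its own; the input shape `(4, 16)` is an arbitrary assumption for illustration.

```python
import numpy as np
import keras
from keras import layers


class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense_1 = layers.Dense(32)
        self.dense_2 = layers.Dense(32)
        self.dense_3 = layers.Dense(1)

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = keras.activations.relu(x)
        x = self.dense_2(x)
        x = keras.activations.relu(x)
        return self.dense_3(x)


block = MLPBlock()
# Weights are created lazily on the first call, once the input shape is known.
y = block(np.ones((4, 16), dtype="float32"))

print(tuple(y.shape))      # one output value per input row
print(len(block.weights))  # kernel and bias for each of the three Dense layers
```

Because the sublayers are assigned as attributes, their weights are tracked automatically by the outer layer.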
KerasHub
The KerasHub library provides Keras 3 implementations of popular model architectures, paired with a collection of pretrained checkpoints available on Kaggle Models. Models can be used for both training and inference, on any of the TensorFlow, JAX, and PyTorch backends.
Computer vision
Take a look at our examples for doing image classification, object detection, video processing, and more.
Natural Language Processing
We also have many guides for NLP, including text classification, machine translation, and language modeling.
Generative Deep Learning
Get started with generative deep learning with our wealth of guides involving state-of-the-art diffusion models, GANs, and transformer models.
Trusted for research and production
Keras is used by CERN, NASA, NIH, and many more scientific organizations around the world (and yes, Keras is used at the Large Hadron Collider). Keras is used by Waymo to power self-driving vehicles. Keras partners with Kaggle and Hugging Face to meet ML developers in the tools they use daily.