Weights & Biases (wandb)
Weights & Biases (W&B) is a platform for experiment tracking, model management, and collaboration in machine learning projects. It records the metrics, hyperparameters, and outputs of each training run, visualizes them in a web dashboard, and lets you share the results with your team.
Getting Started
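Installation
Install the client library from PyPI, then log in with the API key from your W&B account:
pip install wandb
wandb login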
Basic Setup
import wandb
# Initialize a new run
wandb.init(project="my-project", name="experiment-1")
# Your training code here
# ...
# Finish the run
wandb.finish()
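If the training code between `wandb.init()` and `wandb.finish()` raises an exception, `wandb.finish()` is never reached. The run object returned by `wandb.init()` can also be used as a context manager, which finishes the run when the `with` block exits. The sketch below wraps that pattern in a hypothetical `run_experiment` helper. The helper's name and its `train_step` callback are illustrative assumptions, not part of the wandb API. Passing `mode="disabled"` makes logging a no-op, which is convenient in tests.

```python
def run_experiment(train_step, num_steps, project="my-project",
                   name="experiment-1", mode="online"):
    """Call train_step(step) num_steps times and log each returned metrics dict.

    Hypothetical helper: the run is opened with `with`, so it is finished when
    the block exits, whether training completes or raises an exception.
    """
    import wandb  # imported lazily so defining the helper does not require wandb

    # mode="disabled" turns logging into a no-op, e.g. for unit tests
    with wandb.init(project=project, name=name, mode=mode) as run:
        for step in range(num_steps):
            metrics = train_step(step)  # e.g. {"loss": 0.42}
            run.log(metrics)
```

For example, `run_experiment(lambda step: {"loss": 1.0 / (step + 1)}, num_steps=100)` logs a decreasing loss over 100 steps and closes the run afterwards.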
Tracking Experiments
Logging Metrics
import wandb
import random
wandb.init(project="optimization-example")
# Simulate a training loop
for epoch in range(100):
    # Simulate training metrics
    loss = random.uniform(0.1, 1.0) * (0.99 ** epoch)  # Decreasing loss
    accuracy = min(0.99, random.uniform(0.5, 1.0) * (1.01 ** epoch))  # Increasing accuracy
    # Log metrics
    wandb.log({
        "epoch": epoch,
        "loss": loss,
        "accuracy": accuracy,
        "learning_rate": 0.001 * (0.95 ** (epoch // 10))
    })
    print(f"Epoch {epoch}: Loss = {loss:.4f}, Accuracy = {accuracy:.4f}")
wandb.finish()
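The `learning_rate` value logged above follows a step-decay schedule. The base rate of 0.001 is multiplied by 0.95 once every 10 epochs, so it stays constant within each 10-epoch block. Moving the formula into a small function makes the schedule reusable and easy to test. The function name `step_decay_lr` and its parameter names are illustrative assumptions.

```python
def step_decay_lr(epoch, base_lr=0.001, decay=0.95, step_size=10):
    """Learning rate multiplied by `decay` once every `step_size` epochs."""
    return base_lr * decay ** (epoch // step_size)


# Over 30 epochs the rate takes exactly three distinct values
schedule = [step_decay_lr(epoch) for epoch in range(30)]
```

Inside the training loop, the log call could then use `"learning_rate": step_decay_lr(epoch)`. A real optimizer would also need to be updated with the same value for the logged number to reflect training.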
Configuration Tracking
import wandb
# Define hyperparameters
config = {
    "learning_rate": 0.001,
    "batch_size": 32,
    "epochs": 100,
    "model": "resnet50",
    "optimizer": "adam"
}
# Initialize with config
wandb.init(
    project="image-classification",
    name="resnet50-experiment",
    config=config
)
# Access config during training
lr = wandb.config.learning_rate
batch_size = wandb.config.batch_size
# Your training code using the config values
print(f"Training with learning rate: {lr}, batch size: {batch_size}")
wandb.finish()
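Hard-coding the config dict means editing the script for every experiment. A common alternative is to expose each hyperparameter as a command-line flag and pass the parsed values to `wandb.init(config=...)`, so every run records exactly what it was launched with. The sketch below uses only the standard-library `argparse` module. `DEFAULTS` and `parse_config` are hypothetical names chosen for illustration.

```python
import argparse

# Default hyperparameters (the same values as the config dict above)
DEFAULTS = {
    "learning_rate": 0.001,
    "batch_size": 32,
    "epochs": 100,
    "model": "resnet50",
    "optimizer": "adam",
}


def parse_config(argv=None):
    """Return DEFAULTS with any values overridden by --key value flags."""
    parser = argparse.ArgumentParser(description="Training hyperparameters")
    for key, default in DEFAULTS.items():
        # Reuse each default's type (float, int, str) to convert the flag value
        parser.add_argument(f"--{key}", type=type(default), default=default)
    return vars(parser.parse_args(argv))
```

With this in a script (say, a hypothetical `train.py`), you would call `wandb.init(project="image-classification", config=parse_config())` and launch it as `python train.py --learning_rate 0.01 --batch_size 64`. Flags that are not given fall back to the defaults.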
Advanced Logging Features
import wandb
import matplotlib.pyplot as plt
import numpy as np
wandb.init(project="advanced-logging")
# Log images
fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
y = np.sin(x)
ax.plot(x, y)
ax.set_title("Sine Wave")
# Log the matplotlib figure
wandb.log({"sine_wave_plot": wandb.Image(fig)})
plt.close(fig)
# Log histograms
data = np.random.randn(1000)
wandb.log({"data_distribution": wandb.Histogram(data)})
# Log tables
table = wandb.Table(
    columns=["epoch", "loss", "accuracy"],
    data=[
        [1, 0.8, 0.6],
        [2, 0.6, 0.7],
        [3, 0.4, 0.85]
    ]
)
wandb.log({"results_table": table})
# Log custom charts
wandb.log({
    "custom_scatter": wandb.plot.scatter(
        table, "loss", "accuracy", title="Loss vs Accuracy"
    )
})
wandb.finish()
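In a real run, table rows usually come from values collected during training rather than hard-coded literals. One approach is to append one metrics dict per epoch to a list and convert it to a `wandb.Table` once at the end. The helpers `history_to_rows` and `log_history_table` below are hypothetical. They assume each history entry is a dict containing every column name, and `log_history_table` assumes an active run (that is, `wandb.init()` has already been called).

```python
def history_to_rows(history, columns=("epoch", "loss", "accuracy")):
    """Convert a list of per-epoch metric dicts into table rows, one list per epoch."""
    return [[record[column] for column in columns] for record in history]


def log_history_table(history, columns=("epoch", "loss", "accuracy"),
                      key="results_table"):
    """Log the collected history as a single wandb.Table (hypothetical helper)."""
    import wandb  # imported lazily; requires an active run when actually called

    table = wandb.Table(columns=list(columns),
                        data=history_to_rows(history, columns))
    wandb.log({key: table})
    return table
```

During training you would call `history.append({"epoch": epoch, "loss": loss, "accuracy": accuracy})` each epoch, then call `log_history_table(history)` once before `wandb.finish()`. The returned table can also be passed to `wandb.plot.scatter` as in the example above.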
Integration with ML Frameworks
PyTorch Integration
import torch
import torch.nn as nn
import wandb
wandb.init(project="pytorch-integration")
# Watch model gradients and parameters
model = nn.Linear(10, 1)
wandb.watch(model, log="all", log_freq=10)
# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# Fixed synthetic regression batch, so the loss can decrease across epochs
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)
for epoch in range(100):
    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Log metrics
    wandb.log({"epoch": epoch, "loss": loss.item()})
wandb.finish()
Keras Integration
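import tensorflow as tf
import wandb
from wandb.integration.keras import WandbMetricsLogger
wandb.init(project="keras-integration")
# Load MNIST and flatten each 28x28 image into a 784-value vector scaled to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
# Build model
model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation="softmax")
])
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)
# Train; WandbMetricsLogger sends the Keras training and validation metrics to W&B
model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=10,
    callbacks=[WandbMetricsLogger()]
)
wandb.finish()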
Key Features
- Experiment Tracking: Automatically track metrics, hyperparameters, and system information
- Visualization: Rich visualizations including line plots, histograms, images, and custom charts
- Model Management: Version and store your models with automatic lineage tracking
- Collaboration: Share experiments and results with team members
- Hyperparameter Optimization: Built-in sweeps for automated hyperparameter tuning
- Artifacts: Track and version datasets, models, and other files
- Reports: Create interactive reports to document and share your findings
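Sweeps are driven by a configuration dict that names a search method, the metric to optimize, and the range of each hyperparameter. `wandb.sweep()` registers the sweep and returns its ID. `wandb.agent()` then calls a training function repeatedly, and inside each call `wandb.init()` receives the hyperparameter values chosen for that trial. The sketch below assumes a recent wandb client. The parameter ranges, the project name, and the placeholder loss inside `train` are illustrative stand-ins for a real training step, not recommended values.

```python
# Search space: random search minimizing the logged "loss" metric
SWEEP_CONFIG = {
    "method": "random",  # other methods include "grid" and "bayes"
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"min": 0.0001, "max": 0.01},
        "batch_size": {"values": [16, 32, 64]},
    },
}


def train(num_epochs=5):
    """Training function the sweep agent calls once per trial."""
    import wandb  # imported lazily so the module loads without wandb installed

    with wandb.init() as run:
        # The agent fills run.config with this trial's sampled values
        lr = run.config["learning_rate"]
        batch_size = run.config["batch_size"]
        for epoch in range(num_epochs):
            # Placeholder objective: replace with a real training step
            loss = abs(lr - 0.001) + 1.0 / (batch_size * (epoch + 1))
            run.log({"epoch": epoch, "loss": loss})


def launch_sweep(project="sweep-example", count=10):
    """Register the sweep and run `count` trials in this process."""
    import wandb

    sweep_id = wandb.sweep(SWEEP_CONFIG, project=project)
    wandb.agent(sweep_id, function=train, count=count)
    return sweep_id
```

Calling `launch_sweep(count=10)` runs ten trials one after another. Further agents can join the same sweep from other machines with that sweep ID.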
Weights & Biases streamlines the machine learning workflow by providing comprehensive experiment tracking and visualization tools, making it easier to iterate on models and collaborate with team members.