Build an LLM Playground — Part 5: Build a Multi-modal Generation Agent
The fifth entry in the learn-by-doing AI engineer series. We cover the full landscape of visual generation — VAEs, GANs, auto-regressive models, and diffusion models — then go deep on text-to-image and text-to-video pipelines, and build a multi-modal generation agent.
Series: The AI Engineer Learning Path
This is Part 5 of a hands-on series designed to take you from zero to working AI engineer. Every post follows a learn-by-doing philosophy — we explain the theory, then you build something real.
| Part | Topic | Status |
|---|---|---|
| 1 | Build an LLM Playground | Complete |
| 2 | Customer Support Chatbot with RAGs & Prompt Engineering | Complete |
| 3 | "Ask-the-Web" Agent with Tool Calling | Complete |
| 4 | Deep Research with Reasoning Models | Complete |
| 5 | Multi-modal Generation Agent (this post) | Current |
In Parts 1-4, we worked exclusively with text — generating it, retrieving it, reasoning over it, searching for it. Now we cross into the visual world. Image and video generation has exploded over the past three years, moving from research curiosity to production capability. Understanding how these systems work is essential for any AI engineer.
By the end of this post, you'll understand every major approach to visual generation, know how diffusion models work at the math and code level, understand text-to-image and text-to-video pipelines end-to-end, and build a multi-modal generation agent that can create and iterate on images and videos from natural language descriptions.
Part I: Overview of Image and Video Generation
Before diving into diffusion models (the dominant paradigm today), let's understand the full landscape. Each approach represents a different answer to the same fundamental question: how do you learn to generate realistic images from a probability distribution?
The Core Problem
A 256x256 RGB image is a point in a 196,608-dimensional space (256 x 256 x 3 channels). The set of "realistic images" is a tiny manifold within that vast space — most random pixel configurations look like noise. Generative models learn to map from a simple distribution (like Gaussian noise) to this manifold of realistic images.
Random noise (easy to sample)
↓
[Generative Model]
↓
Realistic image (hard to sample directly)
Variational Autoencoders (VAE)
VAEs learn a compressed latent space where similar images are near each other, then generate images by decoding points from that space.
Architecture:
Image (256x256x3)
↓
[Encoder] → Latent vector z (e.g., 512-dimensional)
↓
[Decoder] → Reconstructed image (256x256x3)
The key insight: The encoder doesn't output a single point — it outputs a distribution (mean and variance). During training, we sample from this distribution, which forces the latent space to be smooth and continuous.
The reparameterization trick:
You can't backpropagate through a random sampling operation. The trick: instead of sampling z ~ N(mu, sigma), compute z = mu + sigma * epsilon where epsilon ~ N(0, 1). Now the randomness is in epsilon (which doesn't depend on model parameters), and gradients flow through mu and sigma.
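A tiny standalone sketch (illustrative values) confirms the point: gradients reach mu and log_var even though the sample is random.

```python
import torch

# The trick in isolation: eps carries the randomness, while mu and
# log_var stay in the autograd graph.
mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_var = torch.tensor([0.0, 0.0], requires_grad=True)

sigma = torch.exp(0.5 * log_var)
eps = torch.randn(2)              # randomness lives outside the graph
z = mu + sigma * eps              # differentiable w.r.t. mu and log_var

z.sum().backward()
print(mu.grad)       # tensor([1., 1.]): dz/dmu is 1 regardless of eps
print(log_var.grad)  # depends on eps, but is defined (not None)
```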
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAE(nn.Module):
"""Simple VAE for 28x28 grayscale images (e.g., MNIST)."""
def __init__(self, latent_dim=32):
super().__init__()
self.latent_dim = latent_dim
# Encoder: image → (mu, log_var)
self.encoder = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
)
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_log_var = nn.Linear(256, latent_dim)
# Decoder: z → image
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 28 * 28),
nn.Sigmoid(), # pixel values in [0, 1]
)
def encode(self, x):
h = self.encoder(x)
return self.fc_mu(h), self.fc_log_var(h)
def reparameterize(self, mu, log_var):
"""Sample z using the reparameterization trick."""
std = torch.exp(0.5 * log_var) # standard deviation
eps = torch.randn_like(std) # random noise
return mu + eps * std # z = mu + sigma * epsilon
def decode(self, z):
return self.decoder(z).view(-1, 1, 28, 28)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
reconstruction = self.decode(z)
return reconstruction, mu, log_var
def vae_loss(reconstruction, original, mu, log_var):
"""
VAE loss = Reconstruction loss + KL divergence.
Reconstruction: How well the decoder reproduces the input.
KL divergence: How close the latent distribution is to N(0, 1).
"""
# Binary cross-entropy for reconstruction
recon_loss = F.binary_cross_entropy(
reconstruction.view(-1, 28 * 28),
original.view(-1, 28 * 28),
reduction='sum'
)
# KL divergence: D_KL(q(z|x) || p(z))
# Closed-form for two Gaussians
kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
return recon_loss + kl_loss
Training the VAE:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Load MNIST
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True)
# Train
model = VAE(latent_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(20):
total_loss = 0
for batch, _ in loader:
optimizer.zero_grad()
recon, mu, log_var = model(batch)
loss = vae_loss(recon, batch, mu, log_var)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataset)
print(f"Epoch {epoch + 1}: loss = {avg_loss:.2f}")
# Generate new images by sampling from the latent space
with torch.no_grad():
z = torch.randn(16, 32) # 16 random latent vectors
generated = model.decode(z)
# generated contains 16 images of shape (1, 28, 28)
VAE strengths and weaknesses:
| Strength | Weakness |
|---|---|
| Smooth, structured latent space | Generated images tend to be blurry |
| Principled probabilistic framework | Averages over modes rather than committing to one output |
| Fast generation (single forward pass) | Reconstruction quality limited by bottleneck |
| Good for learning representations | Not competitive with diffusion for image quality |
Where VAEs are used today: VAEs are critical components inside diffusion models. Stable Diffusion uses a VAE to compress images into a latent space, then runs diffusion in that latent space (much cheaper than pixel space). More on this later.
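The "smooth latent space" claim is easy to check by interpolating between two latent codes and decoding each point. A minimal sketch (the small `decoder` here is an untrained stand-in for the trained VAE decoder above):

```python
import torch
import torch.nn as nn

# Walk a straight line between two latent codes and decode each point.
# `decoder` stands in for a trained VAE decoder (same shapes as above).
decoder = nn.Sequential(
    nn.Linear(32, 256), nn.ReLU(),
    nn.Linear(256, 28 * 28), nn.Sigmoid(),
)

with torch.no_grad():
    z_a, z_b = torch.randn(1, 32), torch.randn(1, 32)
    steps = torch.linspace(0, 1, 8).view(-1, 1)       # 8 points on the line
    z_path = (1 - steps) * z_a + steps * z_b          # (8, 32) codes
    frames = decoder(z_path).view(-1, 1, 28, 28)      # (8, 1, 28, 28)

# With a trained VAE, each frame is a plausible digit: the KL term keeps
# the latent space smooth, so nearby codes decode to similar images.
```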
Generative Adversarial Networks (GANs)
GANs train two networks that compete against each other: a generator that creates fake images and a discriminator that tries to tell real from fake.
Random noise z ──→ [Generator G] ──→ Fake image
↓
[Discriminator D] ──→ "Real" or "Fake"
↑
Real image from dataset ─────────────────┘
The minimax game:
- The generator tries to fool the discriminator (make fakes that look real)
- The discriminator tries to catch the generator (distinguish real from fake)
- As they compete, both improve — the generator makes increasingly realistic images
import torch
import torch.nn as nn
class Generator(nn.Module):
"""Generate 28x28 images from random noise."""
def __init__(self, latent_dim=100):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, 28 * 28),
nn.Tanh(), # output in [-1, 1]
)
def forward(self, z):
return self.net(z).view(-1, 1, 28, 28)
class Discriminator(nn.Module):
"""Classify images as real or fake."""
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, x):
return self.net(x)
# Training loop
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()
for epoch in range(50):
for real_images, _ in loader:
real_images = real_images * 2 - 1  # scale [0, 1] pixels to [-1, 1] to match the generator's Tanh output
batch_size = real_images.size(0)
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
# --- Train Discriminator ---
# Real images should be classified as real
d_real_output = discriminator(real_images)
d_real_loss = criterion(d_real_output, real_labels)
# Fake images should be classified as fake
z = torch.randn(batch_size, latent_dim)
fake_images = generator(z).detach() # detach so gradients don't flow to G
d_fake_output = discriminator(fake_images)
d_fake_loss = criterion(d_fake_output, fake_labels)
d_loss = d_real_loss + d_fake_loss
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# --- Train Generator ---
# Generator wants discriminator to classify fakes as real
z = torch.randn(batch_size, latent_dim)
fake_images = generator(z)
g_output = discriminator(fake_images)
g_loss = criterion(g_output, real_labels) # "fool the discriminator"
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
print(f"Epoch {epoch + 1}: D_loss = {d_loss:.4f}, G_loss = {g_loss:.4f}")
Key GAN variants:
| Variant | Key Idea | Notable For |
|---|---|---|
| DCGAN | Convolutional generator/discriminator | First stable GAN architecture |
| StyleGAN (1/2/3) | Style-based generator with a learned mapping network | State-of-the-art face generation |
| Pix2Pix | Paired image-to-image translation | Sketch → photo, satellite → map |
| CycleGAN | Unpaired image translation | Horse → zebra without paired training data |
| ProGAN | Progressive resolution increase during training | High-resolution generation |
GAN strengths and weaknesses:
| Strength | Weakness |
|---|---|
| Sharp, high-quality images | Training instability (mode collapse, oscillation) |
| Fast generation (single forward pass) | Hard to train — requires careful hyperparameter tuning |
| Excellent for specific domains (faces, art styles) | Poor diversity — may only learn subset of distribution |
| Real-time capable | No explicit density estimation |
Where GANs stand today: GANs dominated image generation from 2014 to 2021. Diffusion models have since surpassed them in quality and diversity for general image generation. However, GANs remain useful for real-time applications and domain-specific tasks where speed matters.
Auto-regressive Models
Auto-regressive models generate images one piece at a time, predicting each pixel (or token) conditioned on all previously generated pieces. This is the same approach as GPT for text — just applied to images.
The idea:
Text generation: "The" → "cat" → "sat" → "on" → "the" → "mat"
Image generation: pixel_1 → pixel_2 → pixel_3 → ... → pixel_65536
Generating 65,536 pixels one by one is slow. The breakthrough was VQ-VAE (Vector Quantized VAE): compress the image into a small grid of discrete tokens, then use a transformer to predict those tokens auto-regressively.
Image (256x256)
↓
[VQ-VAE Encoder] → Token grid (32x32 = 1024 tokens)
↓
[Transformer] → Predicts tokens one by one
↓
[VQ-VAE Decoder] → Reconstructed image (256x256)
VQ-VAE: turning images into tokens
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantizer(nn.Module):
"""
Quantize continuous latent vectors to the nearest codebook entry.
This creates discrete "image tokens" — like words in a vocabulary.
"""
def __init__(self, num_embeddings=512, embedding_dim=64):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
# Codebook: a learned vocabulary of 512 latent vectors
self.codebook = nn.Embedding(num_embeddings, embedding_dim)
self.codebook.weight.data.uniform_(
-1.0 / num_embeddings, 1.0 / num_embeddings
)
def forward(self, z):
"""
z: (batch, channels, height, width) — continuous latent from encoder
Returns: quantized z, indices (the token IDs), commitment loss
"""
# Reshape: (B, C, H, W) → (B*H*W, C)
b, c, h, w = z.shape
z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)
# Find nearest codebook entry for each vector
# Distance: ||z - e||^2 = ||z||^2 + ||e||^2 - 2*z*e^T
distances = (
z_flat.pow(2).sum(dim=1, keepdim=True)
+ self.codebook.weight.pow(2).sum(dim=1)
- 2 * z_flat @ self.codebook.weight.t()
)
indices = distances.argmin(dim=1) # closest codebook entry
# Look up the quantized vectors
z_quantized = self.codebook(indices).view(b, h, w, c).permute(0, 3, 1, 2)
# Losses
# Commitment loss: encoder output should stay close to its codebook entry
commitment_loss = F.mse_loss(z, z_quantized.detach())
# Codebook loss: codebook entries should move toward encoder outputs
codebook_loss = F.mse_loss(z.detach(), z_quantized)
# Straight-through estimator: copy gradients from z_quantized to z
z_quantized = z + (z_quantized - z).detach()
# Standard VQ-VAE weighting: codebook loss + beta * commitment loss (beta = 0.25)
return z_quantized, indices.view(b, h, w), codebook_loss + 0.25 * commitment_loss
Auto-regressive generation with a transformer:
Once you have the VQ-VAE, image generation becomes a sequence prediction problem:
# Pseudocode: auto-regressive image generation
def generate_image(transformer, vq_vae_decoder, text_prompt, grid_size=32):
"""
Generate an image token-by-token using a transformer,
then decode with VQ-VAE decoder.
"""
# Encode the text prompt
text_tokens = tokenize(text_prompt)
# Generate image tokens one by one
image_tokens = []
for i in range(grid_size * grid_size):
# Predict next token given text + all previous image tokens
context = text_tokens + image_tokens
logits = transformer(context)
next_token = sample_from(logits[-1]) # sample from the distribution
image_tokens.append(next_token)
# Decode tokens back to image
token_grid = reshape(image_tokens, (grid_size, grid_size))
image = vq_vae_decoder(token_grid)
return image
Notable auto-regressive image models:
| Model | Approach | Key Innovation |
|---|---|---|
| PixelCNN/PixelRNN | Predict pixels directly | First auto-regressive image models |
| VQ-VAE-2 | Hierarchical VQ-VAE + PixelCNN | Multi-scale token grids |
| DALL-E 1 | VQ-VAE + GPT-style transformer | First large-scale text-to-image |
| Parti (Google) | ViT-VQGAN + large transformer | Scaled auto-regressive T2I |
| LlamaGen | Llama architecture for image tokens | LLM architecture for image generation |
Auto-regressive strengths and weaknesses:
| Strength | Weakness |
|---|---|
| Unified framework with text (same architecture) | Slow generation (sequential token-by-token) |
| Naturally handles multi-modal sequences | Token grid limits detail (lossy compression) |
| Well-understood scaling laws from LLMs | Quadratic attention cost in long sequences |
| Easy to add conditioning (text, class labels) | Lower quality than diffusion at same compute |
Diffusion Models
Diffusion models are the current state of the art. They work by learning to reverse a gradual noising process — starting from pure noise and iteratively removing noise until a clean image emerges.
The forward process (adding noise):
Start with a clean image and gradually add Gaussian noise over T steps until you have pure noise.
Clean image → Slightly noisy → Noisier → ... → Pure Gaussian noise
x_0 x_1 x_2 x_T
The reverse process (removing noise):
Learn a neural network that predicts and removes the noise at each step, going from pure noise back to a clean image.
Pure noise → Slightly less noisy → ... → Clean image
x_T x_{T-1} x_0
Why this works: Adding noise is easy (just add random values). The neural network only needs to learn the small step of removing a little noise at each step — a much simpler task than generating an image from scratch in one shot.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleDiffusion:
"""
A minimal diffusion model implementation.
Learns to denoise images step by step.
"""
def __init__(self, num_timesteps=1000):
self.T = num_timesteps
# Noise schedule: beta controls how much noise is added at each step
# Linear schedule from 0.0001 to 0.02
self.betas = torch.linspace(1e-4, 0.02, num_timesteps)
# Precompute useful quantities
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
def forward_process(self, x_0, t, noise=None):
"""
Add noise to a clean image x_0 at timestep t.
q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
The "nice property": we can jump to any timestep directly
without iterating through all previous steps.
"""
if noise is None:
noise = torch.randn_like(x_0)
sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise
return x_t, noise
def training_step(self, model, x_0):
"""
One training step:
1. Sample a random timestep
2. Add noise to the image
3. Predict the noise using the model
4. Compute the loss (predicted noise vs actual noise)
"""
batch_size = x_0.shape[0]
# Random timestep for each image in the batch
t = torch.randint(0, self.T, (batch_size,))
# Add noise
noise = torch.randn_like(x_0)
x_t, _ = self.forward_process(x_0, t, noise)
# Predict the noise
predicted_noise = model(x_t, t)
# Simple MSE loss: predicted noise vs actual noise
loss = F.mse_loss(predicted_noise, noise)
return loss
@torch.no_grad()
def sample(self, model, shape, device='cpu'):
"""
Generate an image by starting from pure noise
and iteratively denoising.
"""
# Start from pure noise
x = torch.randn(shape, device=device)
# Reverse process: denoise step by step
for t in reversed(range(self.T)):
t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
# Predict noise
predicted_noise = model(x, t_batch)
# Compute denoising step
alpha = self.alphas[t]
alpha_bar = self.alphas_cumprod[t]
beta = self.betas[t]
# Mean of the reverse distribution
mean = (1 / torch.sqrt(alpha)) * (
x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
)
# Add noise (except at the final step)
if t > 0:
noise = torch.randn_like(x)
sigma = torch.sqrt(beta)
x = mean + sigma * noise
else:
x = mean
return x
The noise prediction network (simplified U-Net):
class TimeEmbedding(nn.Module):
"""Encode the timestep t as a vector so the network knows 'how noisy' the input is."""
def __init__(self, dim):
super().__init__()
self.dim = dim
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim),
)
def forward(self, t):
# Sinusoidal position encoding (same as transformers)
half_dim = self.dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t.float().unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
return self.mlp(emb)
class SimpleUNet(nn.Module):
"""
Simplified U-Net for noise prediction.
Takes a noisy image and timestep, predicts the noise.
"""
def __init__(self, in_channels=1, base_channels=64, time_dim=256):
super().__init__()
self.time_embed = TimeEmbedding(time_dim)
# Encoder (downsampling)
self.down1 = nn.Sequential(
nn.Conv2d(in_channels, base_channels, 3, padding=1),
nn.GroupNorm(8, base_channels),
nn.GELU(),
)
self.down2 = nn.Sequential(
nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
nn.GroupNorm(8, base_channels * 2),
nn.GELU(),
)
# Bottleneck
self.bottleneck = nn.Sequential(
nn.Conv2d(base_channels * 2, base_channels * 2, 3, padding=1),
nn.GroupNorm(8, base_channels * 2),
nn.GELU(),
)
# Time conditioning: project time embedding to channel dimension
self.time_proj = nn.Linear(time_dim, base_channels * 2)
# Decoder (upsampling) with skip connections
self.up1 = nn.Sequential(
nn.ConvTranspose2d(base_channels * 4, base_channels, 2, stride=2),
nn.GroupNorm(8, base_channels),
nn.GELU(),
)
self.out = nn.Conv2d(base_channels * 2, in_channels, 1)
def forward(self, x, t):
# Time embedding
t_emb = self.time_embed(t)
# Encoder
h1 = self.down1(x) # (B, 64, H, W)
h2 = self.down2(h1) # (B, 128, H/2, W/2)
# Add time conditioning to bottleneck
t_proj = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1)
h = self.bottleneck(h2) + t_proj # (B, 128, H/2, W/2)
# Decoder with skip connections
h = torch.cat([h, h2], dim=1) # (B, 256, H/2, W/2)
h = self.up1(h) # (B, 64, H, W)
h = torch.cat([h, h1], dim=1) # (B, 128, H, W)
return self.out(h)  # (B, 1, H, W) — predicted noise
Training the diffusion model:
# Train on MNIST
diffusion = SimpleDiffusion(num_timesteps=1000)
model = SimpleUNet(in_channels=1, base_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
for epoch in range(50):
total_loss = 0
for batch, _ in loader:
optimizer.zero_grad()
loss = diffusion.training_step(model, batch)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch + 1}: loss = {total_loss / len(loader):.4f}")
# Generate images
samples = diffusion.sample(model, shape=(16, 1, 28, 28))
# samples contains 16 generated 28x28 images
Why diffusion models won:
| Advantage over GANs | Advantage over VAEs | Advantage over Auto-regressive |
|---|---|---|
| Stable training (no adversarial dynamics) | Much sharper images (no blurriness) | Parallel generation (all pixels at once) |
| Better mode coverage (more diverse) | Higher quality reconstruction | Faster generation per image |
| Mathematically principled | Better latent space for manipulation | Better scaling with compute |
Summary: Generation Approaches
| Approach | How It Generates | Speed | Quality | Key Use Today |
|---|---|---|---|---|
| VAE | Decode from latent space | Very fast | Moderate (blurry) | Compression component in diffusion models |
| GAN | Generator fools discriminator | Very fast | High (sharp) | Real-time style transfer, face editing |
| Auto-regressive | Token by token | Slow | High | Multi-modal models (text + image in one model) |
| Diffusion | Iterative denoising | Moderate | Very high | DALL-E 3, Stable Diffusion, Midjourney, Sora |
Part II: Text-to-Image (T2I)
Text-to-image is the most commercially impactful application of generative models. Let's understand how modern T2I systems work end-to-end.
Data Preparation
T2I models need millions of image-text pairs. The quality and diversity of this data determines the quality of the model.
Major datasets:
| Dataset | Size | Source | Notes |
|---|---|---|---|
| LAION-5B | 5.85 billion pairs | Common Crawl (web scraping) | Largest open dataset, used by Stable Diffusion |
| LAION-Aesthetics | ~600M pairs | Filtered LAION-5B | Subset with high aesthetic scores |
| COYO-700M | 700M pairs | Web scraping | Korean-origin, multilingual |
| DataComp | 1.4B pairs | Curated from Common Crawl | Focus on data quality over quantity |
Data pipeline:
1. Scrape image-text pairs from the web (alt text, captions)
↓
2. Filter for quality:
- Remove images smaller than 256x256
- Remove duplicates (perceptual hashing)
- Remove NSFW content (CLIP-based classifier)
- Remove watermarked images
- Remove images with low aesthetic scores
↓
3. Clean captions:
- Remove boilerplate ("Click here to enlarge")
- Standardize formatting
- Optionally re-caption with a vision-language model (BLIP-2, LLaVA)
↓
4. Compute CLIP embeddings for efficient filtering and retrieval
↓
5. Store as WebDataset shards for efficient training I/O
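A sketch of what step 2's filters might look like in code. The `keep_pair` helper, its thresholds, and the boilerplate list are all illustrative; real pipelines use CLIP-based classifiers and perceptual-hash libraries for the heavy lifting.

```python
# Hypothetical quality filter for image-text pairs (illustrative thresholds).
BOILERPLATE = ("click here", "image may contain", "stock photo")

def keep_pair(width: int, height: int, caption: str,
              seen_hashes: set, phash: str, min_side: int = 256) -> bool:
    """Return True if an image-text pair survives the quality filters."""
    if min(width, height) < min_side:          # too small
        return False
    if phash in seen_hashes:                   # duplicate (perceptual hash)
        return False
    text = caption.lower().strip()
    if len(text) < 5:                          # empty / junk caption
        return False
    if any(b in text for b in BOILERPLATE):    # boilerplate caption
        return False
    seen_hashes.add(phash)
    return True

seen = set()
print(keep_pair(512, 512, "A dog on a beach", seen, "abc123"))       # True
print(keep_pair(512, 512, "A dog on a beach", seen, "abc123"))       # False (duplicate)
print(keep_pair(128, 512, "Click here to enlarge", seen, "def456"))  # False (too small)
```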
Re-captioning with a vision model:
Web alt-text is often low quality ("IMG_2847.jpg", "photo", "banner image"). Modern pipelines use a vision-language model to generate better captions:
# Pseudocode: re-caption images with a VLM
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
vlm = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
def recaption_image(image):
"""Generate a detailed caption for an image."""
inputs = processor(image, "Describe this image in detail:", return_tensors="pt")
output = vlm.generate(**inputs, max_new_tokens=100)
caption = processor.decode(output[0], skip_special_tokens=True)
return caption
# Original alt-text: "photo"
# Re-captioned: "A golden retriever sitting on a wooden dock at sunset,
# with a calm lake and mountains in the background"
Diffusion Architectures
Two main architectures are used for the denoising network:
U-Net Architecture (Stable Diffusion 1.x/2.x, DALL-E 2)
The U-Net is an encoder-decoder with skip connections. It was originally designed for medical image segmentation and adapted for diffusion.
Input (noisy latent)
↓
[Down Block 1] ──────────────────────────────→ [Up Block 3] (skip connection)
↓ ↑
[Down Block 2] ──────────────────────────────→ [Up Block 2] (skip connection)
↓ ↑
[Down Block 3] ──────────────────────────────→ [Up Block 1] (skip connection)
↓ ↑
[Middle Block (attention)]
Each block contains:
- Residual convolution layers
- Self-attention layers (for global context)
- Cross-attention layers (for text conditioning)
- Time embedding injection
Cross-attention for text conditioning:
This is how the model "understands" the text prompt. The image features attend to the text embeddings:
class CrossAttention(nn.Module):
"""Image features attend to text embeddings."""
def __init__(self, dim, context_dim, num_heads=8):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.to_q = nn.Linear(dim, dim) # queries from image
self.to_k = nn.Linear(context_dim, dim) # keys from text
self.to_v = nn.Linear(context_dim, dim) # values from text
self.to_out = nn.Linear(dim, dim)
def forward(self, x, context):
"""
x: image features (batch, seq_len, dim)
context: text embeddings (batch, text_len, context_dim)
"""
b, n, _ = x.shape
h = self.num_heads
q = self.to_q(x).view(b, n, h, self.head_dim).transpose(1, 2)
k = self.to_k(context).view(b, -1, h, self.head_dim).transpose(1, 2)
v = self.to_v(context).view(b, -1, h, self.head_dim).transpose(1, 2)
# Attention: image queries attend to text keys
attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
return self.to_out(out)
DiT Architecture (Stable Diffusion 3, FLUX, DALL-E 3)
Diffusion Transformer (DiT) replaces the U-Net with a pure transformer architecture. Instead of convolutions, the image is split into patches (like Vision Transformer) and processed with transformer blocks.
Noisy latent (e.g., 64x64x4)
↓
[Patchify] → sequence of patch tokens (e.g., a 32x32 grid of 2x2 patches = 1024 tokens)
↓
[+ positional embeddings]
↓
[DiT Block 1] ← timestep embedding + text embeddings
↓
[DiT Block 2] ← timestep embedding + text embeddings
↓
... (24-40 blocks)
↓
[DiT Block N]
↓
[Unpatchify] → predicted noise (64x64x4)
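The patchify/unpatchify bookends can be sketched in a few lines; the 64x64x4 latent and 2x2 patch size are the example values from the diagram above.

```python
import torch

def patchify(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    """(B, C, H, W) -> (B, num_patches, C*p*p) sequence of patch tokens."""
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 1, 3, 5)            # (B, H/p, W/p, C, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)

def unpatchify(tokens: torch.Tensor, c: int = 4, p: int = 2) -> torch.Tensor:
    """Inverse of patchify: token sequence back to (B, C, H, W)."""
    b, n, _ = tokens.shape
    g = int(n ** 0.5)                           # patch-grid side (assumes square)
    x = tokens.reshape(b, g, g, c, p, p)
    x = x.permute(0, 3, 1, 4, 2, 5)             # (B, C, H/p, p, W/p, p)
    return x.reshape(b, c, g * p, g * p)

latent = torch.randn(1, 4, 64, 64)
tokens = patchify(latent)                       # (1, 1024, 16)
assert torch.allclose(unpatchify(tokens), latent)  # lossless round trip
```

In a real DiT, a linear layer then projects each 16-dimensional patch vector up to the model dimension before the transformer blocks.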
DiT block with adaptive layer norm (adaLN-Zero):
class DiTBlock(nn.Module):
"""
A single DiT block. Uses adaptive layer norm to inject
timestep and text conditioning.
"""
def __init__(self, dim, num_heads, mlp_ratio=4.0):
super().__init__()
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim),
)
# adaLN-Zero: learn scale and shift from conditioning
# 6 parameters per block: gamma1, beta1, alpha1, gamma2, beta2, alpha2
self.adaln = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim),
)
def forward(self, x, conditioning):
"""
x: patch tokens (batch, seq_len, dim)
conditioning: timestep + text embedding (batch, dim)
"""
# Get adaptive layer norm parameters
adaln_params = self.adaln(conditioning).unsqueeze(1)
gamma1, beta1, alpha1, gamma2, beta2, alpha2 = adaln_params.chunk(6, dim=-1)
# Self-attention with adaLN
h = self.norm1(x) * (1 + gamma1) + beta1
h, _ = self.attn(h, h, h)
x = x + alpha1 * h
# MLP with adaLN
h = self.norm2(x) * (1 + gamma2) + beta2
h = self.mlp(h)
x = x + alpha2 * h
return x
U-Net vs DiT:
| Aspect | U-Net | DiT |
|---|---|---|
| Architecture | CNN + attention | Pure transformer |
| Scaling | Hard to scale uniformly | Scales like LLMs (more layers, more heads) |
| Inductive bias | Strong spatial bias from convolutions | Minimal bias — learned from data |
| Used by | Stable Diffusion 1.x/2.x, DALL-E 2 | Stable Diffusion 3, FLUX, DALL-E 3, Sora |
| Compute | More efficient at lower scales | More efficient at larger scales |
| Trend | Legacy | Current standard |
Diffusion Training
The Forward Process (Adding Noise)
Given a clean image x_0, the forward process produces a noisy version x_t by adding Gaussian noise scaled by the timestep t:
q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
In plain English: x_t is a weighted mix of the original image and random noise, where the weight depends on the timestep.
At t=0, x_t is almost entirely the original image. At t=T, x_t is almost entirely noise.
# The forward process in one line:
x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
Noise schedules control how fast noise is added:
| Schedule | Formula | Characteristics |
|---|---|---|
| Linear | beta_t = beta_min + t/T * (beta_max - beta_min) | Simple, used in original DDPM |
| Cosine | alpha_bar_t = f(t)/f(0), where f(t) = cos((t/T + s) / (1+s) * pi/2)^2 | Smoother, better for high resolution |
| Scaled linear | Adjusted for latent space diffusion | Used in Stable Diffusion |
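To see why the cosine schedule is "smoother", compare how much signal each schedule preserves halfway through. A sketch with T=1000 and the commonly used offset s=0.008:

```python
import math
import torch

# Signal retention at the halfway point: linear vs cosine schedule.
T, s = 1000, 0.008

# Linear schedule: betas rise linearly; alpha_bar is their cumulative product
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar_linear = torch.cumprod(1.0 - betas, dim=0)

# Cosine schedule: alpha_bar defined directly, normalized by its t=0 value
t = torch.arange(T + 1) / T
f = torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
alpha_bar_cosine = f[1:] / f[0]

print(alpha_bar_linear[500])   # ~0.08: most of the signal is already gone
print(alpha_bar_cosine[500])   # ~0.49: about half the signal remains
```

The linear schedule destroys most of the image early, wasting model capacity on nearly-pure-noise timesteps; the cosine schedule spreads the destruction more evenly.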
The Backward Process (Learning to Denoise)
The model learns to predict the noise that was added, so it can be subtracted:
Training objective:
L = E[||noise - model(x_t, t, text)||^2]
In words: The loss is the mean squared error between
the actual noise that was added and the noise the model predicts.
Training algorithm (simplified):
Repeat:
1. Sample a clean image x_0 from the dataset
2. Sample a random timestep t ~ Uniform(1, T)
3. Sample random noise epsilon ~ N(0, I)
4. Compute noisy image: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
5. Predict noise: epsilon_pred = model(x_t, t, text_embedding)
6. Compute loss: L = ||epsilon - epsilon_pred||^2
7. Update model weights via gradient descent
What the model actually predicts — there are three equivalent formulations:
| Prediction Target | What the Model Outputs | Notes |
|---|---|---|
| Noise (epsilon) | The noise added at step t | Most common (DDPM, Stable Diffusion 1.x) |
| Clean image (x_0) | The denoised image directly | Used in some formulations |
| Velocity (v) | v = sqrt(alpha_bar_t) * epsilon - sqrt(1 - alpha_bar_t) * x_0 | Better numerical stability, used in SD 2.x+ |
Latent diffusion (what Stable Diffusion actually does):
Running diffusion in pixel space (256x256x3) is very expensive. Latent diffusion first compresses the image with a VAE, then runs diffusion in the latent space:
Pixel space diffusion: 256x256x3 = 196,608 dimensions
Latent space diffusion: 32x32x4 = 4,096 dimensions (48x cheaper!)
Image → [VAE Encoder] → Latent → [Diffusion in latent space] → Latent → [VAE Decoder] → Image
# Pseudocode: Latent Diffusion Training
def train_step(vae, diffusion_model, text_encoder, image, caption):
# 1. Encode image to latent space (VAE is frozen — pretrained)
with torch.no_grad():
latent = vae.encode(image).latent_dist.sample()
latent = latent * 0.18215 # scaling factor
# 2. Encode text caption
with torch.no_grad():
text_emb = text_encoder(tokenize(caption))
# 3. Standard diffusion training in latent space
t = torch.randint(0, T, (batch_size,))
noise = torch.randn_like(latent)
noisy_latent = forward_process(latent, t, noise)
# 4. Predict noise, conditioned on text
noise_pred = diffusion_model(noisy_latent, t, text_emb)
# 5. Loss
loss = F.mse_loss(noise_pred, noise)
return loss
Diffusion Sampling
At inference time, we start from pure noise and iteratively denoise. The sampling strategy significantly affects speed and quality.
DDPM Sampling (Original)
The original approach: denoise one step at a time through all T timesteps (typically T=1000).
# DDPM sampling: 1000 steps (slow but high quality)
x = torch.randn(1, 4, 64, 64) # pure noise in latent space
for t in reversed(range(1000)):
noise_pred = model(x, t, text_embedding)
x = denoise_step(x, noise_pred, t) # one denoising step
image = vae.decode(x)  # decode latent to pixel image

Problem: 1000 neural network forward passes per image is very slow.
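The `denoise_step` used in the DDPM loop above is the reverse-process posterior update. A minimal sketch — the schedule arrays and argument names here are illustrative, not from any particular library:

```python
import torch

def denoise_step(x_t, noise_pred, t, betas, alpha_bars):
    """One DDPM reverse step: x_t -> x_{t-1}.

    betas: (T,) noise schedule; alpha_bars: (T,) cumulative products of (1 - beta).
    """
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    ab_t = alpha_bars[t]
    # Posterior mean: subtract the (scaled) predicted noise, then rescale
    mean = (x_t - beta_t / torch.sqrt(1 - ab_t) * noise_pred) / torch.sqrt(alpha_t)
    if t == 0:
        return mean  # final step is deterministic — no fresh noise
    # Earlier steps inject fresh noise scaled by the posterior variance
    return mean + torch.sqrt(beta_t) * torch.randn_like(x_t)
```

Note the fresh Gaussian noise added at every step except the last — DDPM sampling is stochastic by design, which is part of why DDIM (below) is interesting.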
DDIM Sampling (Faster)
DDIM (Denoising Diffusion Implicit Models) allows you to skip steps — denoise in 20-50 steps instead of 1000.
# DDIM sampling: 50 steps instead of 1000 — a 20x reduction in model calls
timesteps = [999, 979, 959, ..., 19, 0] # 50 evenly spaced steps
x = torch.randn(1, 4, 64, 64)
for t in timesteps:
noise_pred = model(x, t, text_embedding)
x = ddim_step(x, noise_pred, t, t_prev) # deterministic step
image = vae.decode(x)

DDIM is deterministic — given the same starting noise, you always get the same image. This enables useful features like interpolation between images.
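To see why skipping steps works, note that each DDIM update first estimates x_0 from the current noise prediction, then deterministically re-noises to the earlier timestep. A toy NumPy check (the schedule values are illustrative): with a hypothetically perfect noise predictor, even a 5-step schedule recovers the clean sample exactly:

```python
import numpy as np

def ddim_step(x_t, eps_pred, ab_t, ab_prev):
    """Deterministic DDIM update: estimate x0, then re-noise to alpha_bar_prev."""
    x0_pred = (x_t - np.sqrt(1 - ab_t) * eps_pred) / np.sqrt(ab_t)
    return np.sqrt(ab_prev) * x0_pred + np.sqrt(1 - ab_prev) * eps_pred

rng = np.random.default_rng(0)
x0 = rng.normal(size=8)    # the "clean image"
eps = rng.normal(size=8)   # the true noise

alpha_bars = np.linspace(0.999, 0.01, 1000)        # toy noise schedule
# Visit only 5 of the 1000 timesteps, ending at alpha_bar = 1 (fully clean)
schedule = [alpha_bars[t] for t in (999, 750, 500, 250, 0)] + [1.0]

x = np.sqrt(schedule[0]) * x0 + np.sqrt(1 - schedule[0]) * eps  # start at t=999
for ab_t, ab_prev in zip(schedule[:-1], schedule[1:]):
    x = ddim_step(x, eps, ab_t, ab_prev)  # "perfect" model: eps_pred == eps

print(np.allclose(x, x0))  # True — exact recovery regardless of step count
```

In reality the predictor is imperfect and its errors compound over large jumps, which is why quality degrades below roughly 20 steps.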
Classifier-Free Guidance (CFG)
The most important trick in modern T2I. CFG amplifies the effect of the text prompt by comparing conditioned and unconditioned predictions:
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
- guidance_scale = 1.0: No guidance (model's natural output)
- guidance_scale = 7.5: Standard (good balance of quality and diversity)
- guidance_scale = 15.0+: Very strong guidance (high prompt adherence, less diversity)
def guided_denoise_step(model, x_t, t, text_embedding, guidance_scale=7.5):
"""
Classifier-free guidance: run the model twice (with and without text),
then amplify the difference.
"""
# Unconditional prediction (empty text)
noise_uncond = model(x_t, t, empty_text_embedding)
# Conditional prediction (with text)
noise_cond = model(x_t, t, text_embedding)
# Guided prediction
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
return noise_pred

During training, CFG requires randomly dropping the text conditioning (replacing it with an empty string) some percentage of the time (typically 10-20%). This teaches the model to generate both conditionally and unconditionally.
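At inference time, the two forward passes in `guided_denoise_step` are usually fused into one batched call — duplicate the latent, stack the empty and real text embeddings, and split the output. A sketch (the function name is mine; this is the common optimization, not a specific library's API):

```python
import torch

def guided_noise_pred(model, x_t, t, text_emb, uncond_emb, guidance_scale=7.5):
    """CFG with a single batched forward pass instead of two separate calls."""
    x_in = torch.cat([x_t, x_t], dim=0)                # duplicate the latent
    emb_in = torch.cat([uncond_emb, text_emb], dim=0)  # empty + real conditioning
    noise_uncond, noise_cond = model(x_in, t, emb_in).chunk(2, dim=0)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```

Same math, but one kernel launch per step — at batch size 1 this roughly doubles GPU utilization compared to two sequential calls.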
Modern samplers (DPM-Solver, UniPC, Euler, etc.) can produce good results in 20-30 steps by using higher-order ODE solvers.
Evaluation
How do we measure whether a T2I model is actually good? There are several complementary metrics.
| Metric | What It Measures | How It Works | Good Score |
|---|---|---|---|
| FID (Frechet Inception Distance) | Image quality + diversity | Compare statistics of generated vs real images in Inception feature space | Lower is better (good models: 5-15) |
| IS (Inception Score) | Image quality + diversity | How confidently Inception classifies generated images, and how diverse those classifications are | Higher is better (good models: 50-200+) |
| CLIP Score | Image-text alignment | Cosine similarity between CLIP embeddings of the image and text prompt | Higher is better (0-1 scale) |
| Aesthetic Score | Visual appeal | Trained predictor of human aesthetic preference | Higher is better (1-10 scale) |
| Human Evaluation | Overall quality | Humans compare and rate generated images | Gold standard but expensive |
# Computing CLIP score
import torch
from transformers import CLIPProcessor, CLIPModel
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def clip_score(image, text):
"""Measure how well the image matches the text description."""
inputs = clip_processor(text=[text], images=[image], return_tensors="pt")
outputs = clip_model(**inputs)
# Cosine similarity between image and text embeddings
image_emb = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
text_emb = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
score = (image_emb @ text_emb.T).item()
return score
# Example
score = clip_score(generated_image, "a cat wearing a top hat in a garden")
print(f"CLIP score: {score:.3f}")

# Computing FID (using pytorch-fid)
# pip install pytorch-fid
# Generate 10,000 images and save to a directory
# Then compare against a directory of real images:
# python -m pytorch_fid path/to/real_images path/to/generated_images

Evaluation pitfalls:
| Pitfall | Why It's a Problem |
|---|---|
| FID is sensitive to sample size | Need at least 10K-50K images for stable FID |
| IS doesn't measure text alignment | High IS just means sharp, diverse images |
| CLIP score can be gamed | Model could overfit to CLIP's biases |
| Single metrics are insufficient | Always use multiple metrics + human evaluation |
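For intuition, FID itself is just the Fréchet distance between two Gaussians fit to the feature sets: FID = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2(C_1 C_2)^(1/2)). A from-scratch NumPy sketch of the distance computation (in practice use pytorch-fid, which also runs the Inception network; function names here are mine):

```python
import numpy as np

def psd_sqrt(mat):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    return vecs @ np.diag(np.sqrt(np.clip(vals, 0, None))) @ vecs.T

def fid(feats_real, feats_gen):
    """Frechet distance between Gaussians fit to two (N, D) feature arrays."""
    mu1, mu2 = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    c1 = np.cov(feats_real, rowvar=False)
    c2 = np.cov(feats_gen, rowvar=False)
    s1 = psd_sqrt(c1)
    # Tr((C1 C2)^1/2) equals Tr((S1 C2 S1)^1/2), and S1 C2 S1 is symmetric PSD
    covmean_tr = np.trace(psd_sqrt(s1 @ c2 @ s1))
    return float(((mu1 - mu2) ** 2).sum()
                 + np.trace(c1) + np.trace(c2) - 2 * covmean_tr)

rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))
print(abs(fid(real, real)) < 1e-6)               # True — identical sets score ~0
print(abs(fid(real, real + 2.0) - 32.0) < 1e-3)  # True — mean shift of 2 in 8 dims
```

The mean-shift check also shows why FID needs large samples: with few images, the estimated covariances are noisy and the trace terms dominate the score.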
Part III: Text-to-Video (T2V)
Text-to-video extends diffusion to the temporal dimension. Instead of generating a single image, we generate a sequence of frames that are temporally coherent — objects move smoothly, lighting changes gradually, and the scene makes physical sense across time.
Why Video Is Harder Than Images
| Challenge | Why It's Hard |
|---|---|
| Temporal coherence | Each frame must be consistent with the previous frame — characters shouldn't teleport or morph |
| Data volume | A 4-second video at 24fps is 96 frames. At 256x256, that's 18.9M pixel values per clip (vs 196K for a single image) |
| Motion understanding | The model must learn physics, object permanence, and natural motion from data |
| Compute | Training and inference costs scale linearly (or worse) with video length |
| Evaluation | Harder to measure — temporal quality, motion naturalness, scene consistency |
Latent Diffusion Modeling (LDM) and Compression Networks
Just as Stable Diffusion runs diffusion in a compressed latent space for images, video diffusion models use compression networks to reduce video to a manageable latent representation.
Image VAE vs Video VAE (3D VAE):
Image VAE:
Image (256x256x3) → [2D Encoder] → Latent (32x32x4)
Compression: spatial only (8x downsampling)
Video VAE (3D VAE):
Video (T x 256x256x3) → [3D Encoder] → Latent (T/4 x 32x32x4)
Compression: spatial (8x) AND temporal (4x)
The temporal compression is crucial: a 96-frame video becomes a 24-frame latent sequence, making diffusion tractable.
# Pseudocode: 3D VAE for video compression
class Video3DVAE(nn.Module):
"""
Compress video spatially AND temporally.
Uses 3D convolutions to handle the time dimension.
"""
def __init__(self, spatial_downsample=8, temporal_downsample=4):
super().__init__()
# 3D convolutions: (time, height, width)
self.encoder = nn.Sequential(
# Spatial + temporal downsampling: three stride-2 spatial convs give the 8x
nn.Conv3d(3, 64, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
nn.SiLU(),
nn.Conv3d(64, 128, kernel_size=(3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
nn.SiLU(),
nn.Conv3d(128, 256, kernel_size=(3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
nn.SiLU(),
nn.Conv3d(256, 4, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)),
# Output: (B, 4, T/4, H/8, W/8)
)
self.decoder = nn.Sequential(
# Reverse the compression
nn.Conv3d(4, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)),
nn.SiLU(),
nn.ConvTranspose3d(256, 128, kernel_size=(3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
nn.SiLU(),
nn.ConvTranspose3d(128, 64, kernel_size=(3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
nn.SiLU(),
nn.ConvTranspose3d(64, 3, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
)
def encode(self, video):
"""video: (B, C, T, H, W) → latent: (B, 4, T/4, H/8, W/8)"""
return self.encoder(video)
def decode(self, latent):
"""latent: (B, 4, T/4, H/8, W/8) → video: (B, C, T, H, W)"""
return self.decoder(latent)

Training the video VAE:
- Reconstruction loss (MSE or perceptual loss between input and decoded video)
- KL divergence loss (keep latent close to Gaussian)
- Temporal consistency loss (ensure decoded frames are smooth)
- Adversarial loss (optional — use a discriminator for sharper outputs)
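Combined, the first three terms might look like this (the weights and helper name are illustrative, not from any particular codebase; the adversarial term needs a discriminator and is omitted):

```python
import torch
import torch.nn.functional as F

def video_vae_loss(recon, target, mu, logvar, kl_weight=1e-6, temporal_weight=0.1):
    """recon/target: (B, C, T, H, W); mu/logvar: latent distribution parameters."""
    # 1. Reconstruction: match the decoded video to the input
    rec = F.mse_loss(recon, target)
    # 2. KL: keep the latent distribution close to a standard Gaussian
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    # 3. Temporal consistency: frame-to-frame differences should also match,
    #    penalizing flicker that plain per-frame MSE barely notices
    rec_diff = recon[:, :, 1:] - recon[:, :, :-1]
    tgt_diff = target[:, :, 1:] - target[:, :, :-1]
    temporal = F.mse_loss(rec_diff, tgt_diff)
    return rec + kl_weight * kl + temporal_weight * temporal
```

The tiny KL weight mirrors the image-VAE convention: for diffusion, the latent only needs to be roughly Gaussian-scaled, not a sharp generative prior.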
Data Preparation for T2V
Video data is far messier and more expensive to process than image data.
Data pipeline:
1. Source collection
- Stock video libraries (Shutterstock, Pexels)
- Web-scraped video-text pairs
- Internal datasets
↓
2. Filtering
- Remove static videos (no motion)
- Remove slide shows and text-heavy videos
- Remove videos with excessive camera shake
- Scene-cut detection: split long videos into single-scene clips
- Motion scoring: compute optical flow, filter low/excessive motion
- Aesthetic filtering: score visual quality per frame
↓
3. Standardization
- Resize to target resolution (e.g., 512x512)
- Normalize frame rate (e.g., 24 fps)
- Fixed clip length (e.g., 4 seconds = 96 frames)
- Center crop or pad to square aspect ratio
↓
4. Captioning
- Use a video captioning model (e.g., InternVideo, LLaVA-Video)
- Generate both short captions ("a cat jumping on a table")
and detailed descriptions ("a tabby cat with orange fur leaps from
the floor onto a wooden kitchen table, knocking over a glass of water")
↓
5. Video latent caching
- Pre-encode all videos through the 3D VAE
- Store latent representations to disk
- This avoids re-encoding during training (major speedup)
Video latent caching:
# Pseudocode: pre-compute and cache video latents
def cache_video_latents(dataset, vae, output_dir):
"""
Pre-encode all training videos through the VAE and cache to disk.
This is a one-time cost that saves enormous compute during training.
"""
vae.eval()
for idx, (video, caption) in enumerate(dataset):
with torch.no_grad():
# video: (C, T, H, W) → latent: (4, T/4, H/8, W/8)
latent = vae.encode(video.unsqueeze(0)).squeeze(0)
# Save latent and caption
torch.save({
'latent': latent,
'caption': caption,
}, f"{output_dir}/sample_{idx:08d}.pt")
print(f"Cached {len(dataset)} video latents")
# During training, load latents instead of raw video:
from glob import glob

class CachedVideoDataset:
def __init__(self, cache_dir):
self.files = sorted(glob(f"{cache_dir}/*.pt"))
def __getitem__(self, idx):
data = torch.load(self.files[idx])
return data['latent'], data['caption']

DiT Architecture for Videos
Extending DiT to video means handling an additional time dimension. The key architecture choices:
Patch embedding for video:
Video latent: (T, H, W, C) → e.g., (24, 32, 32, 4)
↓
Patchify: split into (t, h, w) patches, e.g., (2, 2, 2)
↓
Tokens: (T/2 * H/2 * W/2) = 12 * 16 * 16 = 3,072 tokens
↓
Linear projection: each 3D patch → token vector of dimension d
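The patchify step itself is just a reshape. A sketch matching the numbers above (the function name is mine; the final `Linear(32, d)` projection to the model dimension is left out):

```python
import torch

def patchify_3d(latent, pt=2, ph=2, pw=2):
    """latent: (B, C, T, H, W) -> tokens: (B, T/pt * H/ph * W/pw, pt*ph*pw*C)."""
    b, c, t, h, w = latent.shape
    # Split each axis into (num_patches, patch_size)
    x = latent.reshape(b, c, t // pt, pt, h // ph, ph, w // pw, pw)
    # Group patch-grid axes together, then the within-patch axes and channels
    x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)        # (B, T', H', W', pt, ph, pw, C)
    tokens = x.reshape(b, -1, pt * ph * pw * c)  # (B, num_tokens, patch_dim)
    return tokens

latent = torch.randn(1, 4, 24, 32, 32)   # the example latent from above
tokens = patchify_3d(latent)
print(tokens.shape)  # torch.Size([1, 3072, 32]) — 12 * 16 * 16 = 3,072 tokens
```

The temporal patch size of 2 halves the token count again on top of the VAE's 4x temporal compression — cheap wins before attention ever runs.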
Attention strategies for video:
Full 3D attention over all 3,072 tokens is prohibitively expensive (quadratic in sequence length). Modern video DiTs use factored attention:
Option A: Full 3D attention (prohibitive at scale)
Every token attends to every other token
Cost: O((T*H*W)^2) — too expensive for long videos
Option B: Factored attention (practical)
Alternate between:
1. Spatial attention: tokens within the same frame attend to each other
2. Temporal attention: tokens at the same spatial position attend across frames
Cost: O(T * (H*W)^2 + H*W * T^2) — much cheaper
class VideoTransformerBlock(nn.Module):
"""
DiT block for video: alternates spatial and temporal attention.
"""
def __init__(self, dim, num_heads):
super().__init__()
# Spatial self-attention (within each frame)
self.spatial_norm = nn.LayerNorm(dim)
self.spatial_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
# Temporal self-attention (across frames at same spatial position)
self.temporal_norm = nn.LayerNorm(dim)
self.temporal_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
# Cross-attention for text conditioning
self.cross_norm = nn.LayerNorm(dim)
self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
# Feed-forward
self.ff_norm = nn.LayerNorm(dim)
self.ff = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim),
)
def forward(self, x, text_emb, num_frames, h, w):
"""
x: (batch, num_frames * h * w, dim) — flattened video tokens
text_emb: (batch, text_len, dim) — text conditioning
"""
b = x.shape[0]
thw = num_frames * h * w
# 1. Spatial attention: reshape so each frame is a separate sequence
# (batch, T*H*W, dim) → (batch*T, H*W, dim)
x_spatial = x.view(b * num_frames, h * w, -1)
h_spatial = self.spatial_norm(x_spatial)
h_spatial, _ = self.spatial_attn(h_spatial, h_spatial, h_spatial)
x = x + h_spatial.view(b, thw, -1)
# 2. Temporal attention: reshape so each spatial position is a sequence across time
# (batch, T*H*W, dim) → (batch*H*W, T, dim)
x_temp = x.view(b, num_frames, h * w, -1).permute(0, 2, 1, 3).reshape(b * h * w, num_frames, -1)
h_temp = self.temporal_norm(x_temp)
h_temp, _ = self.temporal_attn(h_temp, h_temp, h_temp)
x = x + h_temp.reshape(b, h * w, num_frames, -1).permute(0, 2, 1, 3).reshape(b, thw, -1)
# 3. Cross-attention with text
h_cross = self.cross_norm(x)
h_cross, _ = self.cross_attn(h_cross, text_emb, text_emb)
x = x + h_cross
# 4. Feed-forward
x = x + self.ff(self.ff_norm(x))
return x

Large-Scale Training Challenges
Training video generation models is one of the most compute-intensive tasks in AI.
| Challenge | Scale | Mitigation |
|---|---|---|
| Compute | Sora reportedly trained on 10,000+ H100 GPUs | Model parallelism, gradient checkpointing, mixed precision |
| Memory | A single 4-second, 1080p video latent can be 100+ MB | Gradient accumulation, activation checkpointing, offloading |
| Data | Millions of high-quality video-text pairs | Aggressive filtering, synthetic captions, progressive training |
| Stability | Training can diverge at large scale | Learning rate warmup, gradient clipping, loss spike detection |
| Evaluation | No single metric captures video quality | Automated metrics + human eval, temporal quality metrics |
Progressive training strategy:
Instead of training at full resolution and duration from the start, gradually increase complexity:
Phase 1: Image pretraining
Train the DiT on images (single frame)
Resolution: 256x256
Duration: 10% of total training compute
Phase 2: Short, low-resolution video
Fine-tune on 16-frame, 256x256 video
The model learns basic temporal dynamics
Duration: 30% of total training compute
Phase 3: Longer, higher-resolution video
Fine-tune on 64-frame, 512x512 video
The model learns longer temporal consistency
Duration: 40% of total training compute
Phase 4: Full resolution
Fine-tune on 96+ frame, 1080p video
Duration: 20% of total training compute
Mixed-resolution training: Train on multiple resolutions simultaneously. Pack shorter/smaller videos into the same batch as longer/larger ones to maximize GPU utilization.
T2V Overall System
A complete text-to-video system has many components beyond the diffusion model itself:
┌──────────────────────────────────────────────────────────────────────────┐
│ Text-to-Video System │
│ │
│ "A golden retriever running through a field of sunflowers at sunset" │
│ ↓ │
│ ┌──────────────────────────────────┐ │
│ │ Text Encoder (e.g., T5-XXL) │ │
│ │ Converts text to embeddings │ │
│ └──────────────┬───────────────────┘ │
│ ↓ │
│ ┌──────────────────────────────────┐ │
│ │ Video DiT │ │
│ │ Iterative denoising in latent │ ← Noise schedule │
│ │ space with text conditioning │ ← CFG guidance │
│ │ (20-50 sampling steps) │ ← Sampler (DPM, Euler, etc.) │
│ └──────────────┬───────────────────┘ │
│ ↓ │
│ ┌──────────────────────────────────┐ │
│ │ 3D VAE Decoder │ │
│ │ Latent → pixel-space video │ │
│ └──────────────┬───────────────────┘ │
│ ↓ │
│ ┌──────────────────────────────────┐ │
│ │ Post-Processing │ │
│ │ - Frame interpolation (24→60fps) │ │
│ │ - Super-resolution (optional) │ │
│ │ - Temporal smoothing │ │
│ └──────────────┬───────────────────┘ │
│ ↓ │
│ Final video (4-10 seconds, 1080p, 24-60fps) │
└──────────────────────────────────────────────────────────────────────────┘
Key components:
| Component | Role | Examples |
|---|---|---|
| Text encoder | Convert prompt to embeddings that condition the diffusion model | T5-XXL, CLIP text encoder, or both |
| Video DiT | The core diffusion model — denoises video latents | Custom architecture per company |
| 3D VAE | Compress/decompress video to/from latent space | Trained separately |
| Sampler | Algorithm for stepping through the denoising process | DDIM, DPM-Solver++, Euler |
| Super-resolution | Upscale to final resolution | Separate diffusion model or upscaler |
| Frame interpolation | Increase frame rate | FILM, RIFE |
Notable T2V models:
| Model | Organization | Key Innovation |
|---|---|---|
| Sora | OpenAI | Scaling DiT to long, high-res video; strong physics understanding |
| Kling | Kuaishou | Competitive with Sora; publicly available |
| Runway Gen-3 | Runway | Motion control, camera movement specification |
| Stable Video Diffusion | Stability AI | Open-weight video diffusion model |
| CogVideoX | Tsinghua/Zhipu | Open-weight, strong temporal coherence |
| Veo 2 | Google DeepMind | High fidelity, long-form video |
| Wan | Alibaba | Open-weight, 14B parameter video DiT |
Part IV: Build a Multi-modal Generation Agent
Now let's build an agent that uses image generation models as tools. This agent can understand text descriptions, generate images, refine them based on feedback, and orchestrate a multi-step creative workflow.
Architecture
┌──────────────────────────────────────────────────────────────┐
│ Multi-modal Generation Agent │
│ │
│ User: "Create a logo for a coffee shop called 'Bean There'" │
│ ↓ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ LLM Orchestrator (Claude) │ │
│ │ │ │
│ │ 1. Analyze the request │ │
│ │ 2. Craft an optimized image generation prompt │ │
│ │ 3. Call image generation tool │ │
│ │ 4. Describe the result to the user │ │
│ │ 5. Iterate based on feedback │ │
│ └───────────┬───────────────────────────────────────┘ │
│ │ │
│ Tools: │ │
│ ┌───────────▼───────────┐ ┌──────────────────────┐ │
│ │ generate_image │ │ edit_image │ │
│ │ (DALL-E / SD API) │ │ (inpainting) │ │
│ └───────────────────────┘ └──────────────────────┘ │
│ ┌───────────────────────┐ ┌──────────────────────┐ │
│ │ analyze_image │ │ generate_video │ │
│ │ (vision model) │ │ (Runway / Kling API) │ │
│ └───────────────────────┘ └──────────────────────┘ │
└──────────────────────────────────────────────────────────────┘
Implementation
import anthropic
import openai
import json
import base64
import os
import httpx
from pathlib import Path
# --- Clients ---
claude = anthropic.Anthropic()
oai = openai.OpenAI()
STABILITY_API_KEY = os.environ["STABILITY_API_KEY"]
# --- Tool Definitions ---
tools = [
{
"name": "generate_image",
"description": (
"Generate an image from a text prompt using DALL-E. "
"The prompt should be detailed and descriptive. "
"Returns a URL to the generated image."
),
"input_schema": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Detailed description of the image to generate"
},
"size": {
"type": "string",
"enum": ["1024x1024", "1792x1024", "1024x1792"],
"description": "Image dimensions"
},
"style": {
"type": "string",
"enum": ["vivid", "natural"],
"description": "vivid for hyper-real/dramatic, natural for realistic"
}
},
"required": ["prompt"]
}
},
{
"name": "generate_image_sdxl",
"description": (
"Generate an image using Stable Diffusion 3 via the Stability API. "
"Good for artistic styles, illustrations, and when you need more "
"control over the generation process."
),
"input_schema": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Detailed description of the image to generate"
},
"negative_prompt": {
"type": "string",
"description": "What to avoid in the image (e.g., 'blurry, low quality')"
},
"style_preset": {
"type": "string",
"enum": [
"photographic", "digital-art", "comic-book",
"fantasy-art", "analog-film", "cinematic"
],
"description": "Visual style preset"
}
},
"required": ["prompt"]
}
},
{
"name": "analyze_image",
"description": (
"Analyze an image using Claude's vision capabilities. "
"Can describe what's in the image, evaluate quality, "
"suggest improvements, or answer questions about it."
),
"input_schema": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "URL of the image to analyze"
},
"question": {
"type": "string",
"description": "What to analyze or ask about the image"
}
},
"required": ["image_url", "question"]
}
},
{
"name": "refine_prompt",
"description": (
"Take a user's simple description and expand it into an "
"optimized prompt for image generation. Adds details about "
"lighting, composition, style, and technical parameters."
),
"input_schema": {
"type": "object",
"properties": {
"user_description": {
"type": "string",
"description": "The user's original, possibly brief description"
},
"style_preference": {
"type": "string",
"description": "Preferred style (photorealistic, illustration, etc.)"
}
},
"required": ["user_description"]
}
}
]
# --- Tool Implementations ---
def generate_image_dalle(prompt, size="1024x1024", style="vivid"):
"""Generate an image using DALL-E 3."""
response = oai.images.generate(
model="dall-e-3",
prompt=prompt,
size=size,
style=style,
n=1,
)
return json.dumps({
"image_url": response.data[0].url,
"revised_prompt": response.data[0].revised_prompt
})
def generate_image_sdxl(prompt, negative_prompt="", style_preset="photographic"):
"""Generate an image via Stability AI's Stable Image API."""
response = httpx.post(
"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"Authorization": f"Bearer {STABILITY_API_KEY}", "Accept": "image/*"},
files={"none": ""},
data={
"prompt": prompt,
"negative_prompt": negative_prompt or "blurry, low quality, distorted",
"style_preset": style_preset,
"output_format": "png",
}
)
# Save image and return path
path = f"generated_{hash(prompt)}.png"
Path(path).write_bytes(response.content)
return json.dumps({"image_path": path, "prompt_used": prompt})
def analyze_image(image_url, question):
"""Analyze an image using Claude's vision."""
response = claude.messages.create(
model="claude-sonnet-4-6",
max_tokens=1024,
messages=[{
"role": "user",
"content": [
{"type": "image", "source": {"type": "url", "url": image_url}},
{"type": "text", "text": question}
]
}]
)
return response.content[0].text
def refine_prompt(user_description, style_preference=""):
"""Expand a simple description into a detailed generation prompt."""
response = claude.messages.create(
model="claude-sonnet-4-6",
max_tokens=500,
messages=[{
"role": "user",
"content": f"""Expand this image description into a detailed prompt
optimized for AI image generation. Add specific details about:
- Composition and framing
- Lighting and atmosphere
- Color palette
- Style and medium
- Important details to include
User description: {user_description}
Style preference: {style_preference or "not specified"}
Return ONLY the expanded prompt, nothing else."""
}]
)
return response.content[0].text
def execute_tool(name, args):
"""Route tool calls to implementations."""
if name == "generate_image":
return generate_image_dalle(
args["prompt"],
args.get("size", "1024x1024"),
args.get("style", "vivid")
)
elif name == "generate_image_sdxl":
return generate_image_sdxl(
args["prompt"],
args.get("negative_prompt", ""),
args.get("style_preset", "photographic")
)
elif name == "analyze_image":
return analyze_image(args["image_url"], args["question"])
elif name == "refine_prompt":
return refine_prompt(
args["user_description"],
args.get("style_preference", "")
)
return f"Unknown tool: {name}"
# --- Agent ---
SYSTEM_PROMPT = """You are a multi-modal creative agent that helps users generate
and refine images and videos. You have access to multiple generation models.
Your workflow:
1. Understand what the user wants to create
2. Use refine_prompt to expand vague descriptions into detailed generation prompts
3. Choose the appropriate generation tool (DALL-E for photorealism, SDXL for artistic styles)
4. After generation, use analyze_image to evaluate the result
5. If the result needs improvement, iterate with adjusted prompts
6. Present the result to the user with a description of what was generated
Tips for great prompts:
- Be specific about composition, lighting, and style
- Include the medium (photograph, oil painting, 3D render, etc.)
- Specify camera angle if relevant (close-up, wide shot, aerial view)
- Include mood and atmosphere keywords
- For logos/designs, specify that text should be clear and readable
Always explain your creative decisions to the user."""
def creative_agent(user_request: str):
"""Run the multi-modal generation agent."""
messages = [{"role": "user", "content": user_request}]
for step in range(10): # max 10 tool-use steps
response = claude.messages.create(
model="claude-sonnet-4-6",
max_tokens=4096,
system=SYSTEM_PROMPT,
tools=tools,
messages=messages,
)
messages.append({"role": "assistant", "content": response.content})
if response.stop_reason == "tool_use":
tool_results = []
for block in response.content:
if block.type == "tool_use":
print(f" Tool: {block.name}({json.dumps(block.input)[:100]}...)")
result = execute_tool(block.name, block.input)
print(f" Result: {result[:200]}...")
tool_results.append({
"type": "tool_result",
"tool_use_id": block.id,
"content": result
})
messages.append({"role": "user", "content": tool_results})
else:
# Final response
text = "".join(
block.text for block in response.content if hasattr(block, "text")
)
return text
return "Agent reached maximum steps."
# --- Run it ---
if __name__ == "__main__":
result = creative_agent(
"Create a logo for a coffee shop called 'Bean There, Done That'. "
"The style should be warm and inviting, with a vintage feel."
)
print(result)

Iterative Refinement Loop
The real power of an agent-based approach is iterative refinement. The agent can generate, analyze, critique, and regenerate:
def iterative_generation(
user_request: str,
max_iterations: int = 3
) -> list:
"""
Generate an image, analyze it, and iterate until satisfied.
Returns the full history of generations.
"""
history = []
# Initial prompt refinement
refined_prompt = refine_prompt(user_request)
print(f"Refined prompt: {refined_prompt[:100]}...")
for iteration in range(max_iterations):
print(f"\n--- Iteration {iteration + 1} ---")
# Generate
result = json.loads(generate_image_dalle(refined_prompt))
image_url = result["image_url"]
print(f"Generated image: {image_url[:80]}...")
# Analyze
analysis = analyze_image(
image_url,
f"Evaluate this image against the original request: '{user_request}'. "
f"Score it 1-10 on: accuracy to request, visual quality, composition. "
f"List specific issues that should be fixed. "
f"If it scores 8+ on all criteria, say 'APPROVED'."
)
print(f"Analysis: {analysis[:200]}...")
history.append({
"iteration": iteration + 1,
"prompt": refined_prompt,
"image_url": image_url,
"analysis": analysis
})
# Check if approved
if "APPROVED" in analysis.upper():
print("Image approved!")
break
# Refine prompt based on critique
refinement = claude.messages.create(
model="claude-sonnet-4-6",
max_tokens=500,
messages=[{
"role": "user",
"content": f"""The previous image generation wasn't quite right.
Original request: {user_request}
Prompt used: {refined_prompt}
Critique: {analysis}
Write an improved prompt that addresses the specific issues mentioned
in the critique. Return ONLY the new prompt."""
}]
)
refined_prompt = refinement.content[0].text
print(f"New prompt: {refined_prompt[:100]}...")
return history

What You Should Know After Reading This
- What are the four main approaches to image generation (VAE, GAN, auto-regressive, diffusion)?
- How does the VAE reparameterization trick work and why is it necessary?
- How does GAN training work and why is it unstable?
- How do VQ-VAEs turn images into discrete tokens for auto-regressive generation?
- What is the diffusion forward process (adding noise) and reverse process (denoising)?
- What is latent diffusion and why is it more efficient than pixel-space diffusion?
- What is the difference between U-Net and DiT architectures for diffusion?
- How does cross-attention enable text conditioning in T2I models?
- What is classifier-free guidance (CFG) and how does it improve text-image alignment?
- What are FID, IS, and CLIP score, and what does each measure?
- How does a 3D VAE compress video spatially and temporally?
- What is factored attention and why is it necessary for video generation?
- What is the progressive training strategy for T2V models?
- How would you build an agent that uses image generation as a tool?
Further Reading
- "Auto-Encoding Variational Bayes" (Kingma and Welling, 2013) — The original VAE paper
- "Generative Adversarial Nets" (Goodfellow et al., 2014) — The original GAN paper
- "Denoising Diffusion Probabilistic Models" (Ho et al., 2020) — DDPM, the foundational diffusion paper
- "High-Resolution Image Synthesis with Latent Diffusion Models" (Rombach et al., 2022) — Stable Diffusion / Latent Diffusion
- "Scalable Diffusion Models with Transformers" (Peebles and Xie, 2023) — The DiT paper
- "Denoising Diffusion Implicit Models" (Song et al., 2020) — DDIM fast sampling
- "Classifier-Free Diffusion Guidance" (Ho and Salimans, 2022) — CFG
- "Neural Discrete Representation Learning" (van den Oord et al., 2017) — VQ-VAE
- "Video generation models as world simulators" (OpenAI, 2024) — Sora technical report
- "CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer" (Yang et al., 2024)
- "Stable Video Diffusion: Scaling Latent Video Diffusion Models to Large Datasets" (Blattmann et al., 2023)
- "Understanding Diffusion Models: A Unified Perspective" (Luo, 2022) — Excellent tutorial paper