
Build an LLM Playground — Part 5: Build a Multi-modal Generation Agent

The fifth entry in the learn-by-doing AI engineer series. We cover the full landscape of visual generation — VAEs, GANs, auto-regressive models, and diffusion models — then go deep on text-to-image and text-to-video pipelines, and build a multi-modal generation agent.


Series: The AI Engineer Learning Path

This is Part 5 of a hands-on series designed to take you from zero to working AI engineer. Every post follows a learn-by-doing philosophy — we explain the theory, then you build something real.

| Part | Topic | Status |
| --- | --- | --- |
| 1 | Build an LLM Playground | Complete |
| 2 | Customer Support Chatbot with RAGs & Prompt Engineering | Complete |
| 3 | "Ask-the-Web" Agent with Tool Calling | Complete |
| 4 | Deep Research with Reasoning Models | Complete |
| 5 | Multi-modal Generation Agent (this post) | Current |

In Parts 1-4, we worked exclusively with text — generating it, retrieving it, reasoning over it, searching for it. Now we cross into the visual world. Image and video generation has exploded over the past three years, moving from research curiosity to production capability. Understanding how these systems work is essential for any AI engineer.

By the end of this post, you'll understand every major approach to visual generation, know how diffusion models work at the math and code level, understand text-to-image and text-to-video pipelines end-to-end, and build a multi-modal generation agent that can create and iterate on images and videos from natural language descriptions.


Part I: Overview of Image and Video Generation

Before diving into diffusion models (the dominant paradigm today), let's understand the full landscape. Each approach represents a different answer to the same fundamental question: how do you learn the distribution of realistic images well enough to sample new ones from it?

The Core Problem

A 256x256 RGB image is a point in a 196,608-dimensional space (256 x 256 x 3 channels). The set of "realistic images" is a tiny manifold within that vast space — most random pixel configurations look like noise. Generative models learn to map from a simple distribution (like Gaussian noise) to this manifold of realistic images.

Random noise (easy to sample)
        ↓
  [Generative Model]
        ↓
Realistic image (hard to sample directly)

Variational Autoencoders (VAE)

VAEs learn a compressed latent space where similar images are near each other, then generate images by decoding points from that space.

Architecture:

Image (256x256x3)
    ↓
[Encoder] → Latent vector z (e.g., 512-dimensional)
                ↓
            [Decoder] → Reconstructed image (256x256x3)

The key insight: The encoder doesn't output a single point — it outputs a distribution (mean and variance). During training, we sample from this distribution, which forces the latent space to be smooth and continuous.

The reparameterization trick:

You can't backpropagate through a random sampling operation. The trick: instead of sampling z ~ N(mu, sigma^2) directly, compute z = mu + sigma * epsilon where epsilon ~ N(0, 1). Now the randomness is in epsilon (which doesn't depend on model parameters), and gradients flow through mu and sigma.
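To make the gradient claim concrete, here is a minimal check, assuming PyTorch and arbitrary made-up values for `mu` and `log_var`. It is only a sketch of the idea, separate from the VAE class that follows.

```python
import torch

torch.manual_seed(0)

# Encoder outputs for a single 3-dimensional latent (values are arbitrary).
mu = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
log_var = torch.tensor([0.0, -2.0, 1.0], requires_grad=True)

# Reparameterized sample: all randomness lives in eps.
eps = torch.randn(3)
std = torch.exp(0.5 * log_var)
z = mu + std * eps

# Backpropagate through the sample.
z.sum().backward()

# dz/dmu = 1 and dz/dlog_var = 0.5 * std * eps, so both gradients are defined.
print(mu.grad)
print(log_var.grad)
```

Because `eps` is drawn independently of the parameters, autograd treats it as a constant and the sample is an ordinary differentiable function of `mu` and `log_var`.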

import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
class VAE(nn.Module):
    """Simple VAE for 28x28 grayscale images (e.g., MNIST)."""
 
    def __init__(self, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
 
        # Encoder: image → (mu, log_var)
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_log_var = nn.Linear(256, latent_dim)
 
        # Decoder: z → image
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid(),  # pixel values in [0, 1]
        )
 
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_log_var(h)
 
    def reparameterize(self, mu, log_var):
        """Sample z using the reparameterization trick."""
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)      # random noise
        return mu + eps * std            # z = mu + sigma * epsilon
 
    def decode(self, z):
        return self.decoder(z).view(-1, 1, 28, 28)
 
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        reconstruction = self.decode(z)
        return reconstruction, mu, log_var
 
 
def vae_loss(reconstruction, original, mu, log_var):
    """
    VAE loss = Reconstruction loss + KL divergence.
    Reconstruction: How well the decoder reproduces the input.
    KL divergence: How close the latent distribution is to N(0, 1).
    """
    # Binary cross-entropy for reconstruction
    recon_loss = F.binary_cross_entropy(
        reconstruction.view(-1, 28 * 28),
        original.view(-1, 28 * 28),
        reduction='sum'
    )
 
    # KL divergence: D_KL(q(z|x) || p(z))
    # Closed-form for two Gaussians
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
 
    return recon_loss + kl_loss
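
For reference, the `kl_loss` line follows from the standard closed form for the KL divergence between a diagonal Gaussian and a standard normal. The derivation is textbook material rather than something specific to this post; `log_var` in the code corresponds to log sigma_j^2 below.

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
  = \frac{1}{2} \sum_{j} \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)
  = -\frac{1}{2} \sum_{j} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
```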

Training the VAE:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
 
# Load MNIST
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True)
 
# Train
model = VAE(latent_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
for epoch in range(20):
    total_loss = 0
    for batch, _ in loader:
        optimizer.zero_grad()
        recon, mu, log_var = model(batch)
        loss = vae_loss(recon, batch, mu, log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
 
    avg_loss = total_loss / len(dataset)
    print(f"Epoch {epoch + 1}: loss = {avg_loss:.2f}")
 
# Generate new images by sampling from the latent space
with torch.no_grad():
    z = torch.randn(16, 32)  # 16 random latent vectors
    generated = model.decode(z)
    # generated is 16 images of shape (1, 28, 28)

VAE strengths and weaknesses:

| Strength | Weakness |
| --- | --- |
| Smooth, structured latent space | Generated images tend to be blurry |
| Principled probabilistic framework | Mode averaging: tends to blend possibilities instead of committing to one |
| Fast generation (single forward pass) | Reconstruction quality limited by bottleneck |
| Good for learning representations | Not competitive with diffusion for image quality |

Where VAEs are used today: VAEs are critical components inside diffusion models. Stable Diffusion uses a VAE to compress images into a latent space, then runs diffusion in that latent space (much cheaper than pixel space). More on this later.
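
To get a feel for how much cheaper latent space is, here is a back-of-the-envelope calculation. The numbers are assumptions based on Stable Diffusion 1.x (8x spatial downsampling, 4 latent channels, 512x512 input); other models use different factors.

```python
# Assumed figures for Stable Diffusion 1.x: 8x downsampling, 4 latent channels.
height, width, channels = 512, 512, 3
downsample, latent_channels = 8, 4

pixel_values = height * width * channels
latent_values = (height // downsample) * (width // downsample) * latent_channels

print(pixel_values, latent_values, pixel_values / latent_values)
# 786432 16384 48.0
```

Under these assumptions the denoising network processes 48x fewer values per step than it would in pixel space.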


Generative Adversarial Networks (GANs)

GANs train two networks that compete against each other: a generator that creates fake images and a discriminator that tries to tell real from fake.

Random noise z ──→ [Generator G] ──→ Fake image
                                         ↓
                                    [Discriminator D] ──→ "Real" or "Fake"
                                         ↑
Real image from dataset ─────────────────┘

The minimax game:

  • The generator tries to fool the discriminator (make fakes that look real)
  • The discriminator tries to catch the generator (distinguish real from fake)
  • As they compete, both improve — the generator makes increasingly realistic images
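
For reference, the competition in the bullets above is usually written as the following minimax objective. This is the standard formulation from the GAN literature, stated here as background. The second line is the non-saturating generator loss, which is what binary cross-entropy against "real" labels computes.

```latex
\min_G \max_D \; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]

\mathcal{L}_G^{\text{non-saturating}} = -\,\mathbb{E}_{z \sim p_z}\big[\log D(G(z))\big]
```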

import torch
import torch.nn as nn
 
 
class Generator(nn.Module):
    """Generate 28x28 images from random noise."""
 
    def __init__(self, latent_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 28 * 28),
            nn.Tanh(),  # output in [-1, 1]
        )
 
    def forward(self, z):
        return self.net(z).view(-1, 1, 28, 28)
 
 
class Discriminator(nn.Module):
    """Classify images as real or fake."""
 
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
 
    def forward(self, x):
        return self.net(x)
 
 
# Training loop
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()
 
for epoch in range(50):
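    for real_images, _ in loader:
        # `loader` (from the VAE section) yields pixels in [0, 1]; rescale to
        # [-1, 1] so real images match the range of the generator's Tanh output
        real_images = real_images * 2 - 1
        batch_size = real_images.size(0)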
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
 
        # --- Train Discriminator ---
        # Real images should be classified as real
        d_real_output = discriminator(real_images)
        d_real_loss = criterion(d_real_output, real_labels)
 
        # Fake images should be classified as fake
        z = torch.randn(batch_size, latent_dim)
        fake_images = generator(z).detach()  # detach so gradients don't flow to G
        d_fake_output = discriminator(fake_images)
        d_fake_loss = criterion(d_fake_output, fake_labels)
 
        d_loss = d_real_loss + d_fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
 
        # --- Train Generator ---
        # Generator wants discriminator to classify fakes as real
        z = torch.randn(batch_size, latent_dim)
        fake_images = generator(z)
        g_output = discriminator(fake_images)
        g_loss = criterion(g_output, real_labels)  # "fool the discriminator"
 
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
 
    # Losses from the last batch of the epoch
    print(f"Epoch {epoch + 1}: D_loss = {d_loss.item():.4f}, G_loss = {g_loss.item():.4f}")

Key GAN variants:

| Variant | Key Idea | Notable For |
| --- | --- | --- |
| DCGAN | Convolutional generator/discriminator | Early, widely adopted stable GAN architecture |
| StyleGAN (1/2/3) | Style-based generator (progressive growing in v1 only) | State-of-the-art face generation |
| Pix2Pix | Paired image-to-image translation | Sketch → photo, satellite → map |
| CycleGAN | Unpaired image translation | Horse → zebra without paired training data |
| ProGAN | Progressive resolution increase during training | High-resolution generation |

GAN strengths and weaknesses:

| Strength | Weakness |
| --- | --- |
| Sharp, high-quality images | Training instability (mode collapse, oscillation) |
| Fast generation (single forward pass) | Hard to train; requires careful hyperparameter tuning |
| Excellent for specific domains (faces, art styles) | Poor diversity: may learn only a subset of the distribution |
| Real-time capable | No explicit density estimation |

Where GANs stand today: GANs dominated image generation from 2014 to 2021. Diffusion models have since surpassed them in quality and diversity for general image generation. However, GANs remain useful for real-time applications and domain-specific tasks where speed matters.


Auto-regressive Models

Auto-regressive models generate images one piece at a time, predicting each pixel (or token) conditioned on all previously generated pieces. This is the same approach as GPT for text — just applied to images.

The idea:

Text generation:     "The" → "cat" → "sat" → "on" → "the" → "mat"
Image generation:    pixel_1 → pixel_2 → pixel_3 → ... → pixel_65536

Generating 65,536 pixels one by one is slow. The breakthrough was VQ-VAE (Vector Quantized VAE): compress the image into a small grid of discrete tokens, then use a transformer to predict those tokens auto-regressively.

Image (256x256)
    ↓
[VQ-VAE Encoder] → Token grid (32x32 = 1024 tokens)
                        ↓
                [Transformer] → Predicts tokens one by one
                        ↓
[VQ-VAE Decoder] → Reconstructed image (256x256)
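
The savings from tokenizing can be checked with a few lines of arithmetic. This sketch uses the 256x256 image and 32x32 grid from the text; the 8x downsampling factor is implied by those two numbers.

```python
# Sequence lengths for a 256x256 image, following the numbers in the text.
image_side = 256
downsample = 8  # 256 / 32: each token covers an 8x8 pixel patch
grid_side = image_side // downsample

num_pixels = image_side * image_side  # positions when predicting pixel by pixel
num_tokens = grid_side * grid_side    # positions when predicting VQ-VAE tokens

shorter = num_pixels // num_tokens
# Self-attention cost grows with the square of the sequence length.
attention_saving = shorter ** 2

print(num_pixels, num_tokens, shorter, attention_saving)
# 65536 1024 64 4096
```

The sequence is 64x shorter, and because attention cost is quadratic in length, the attention computation shrinks by roughly 4096x.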

VQ-VAE: turning images into tokens

import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
class VectorQuantizer(nn.Module):
    """
    Quantize continuous latent vectors to the nearest codebook entry.
    This creates discrete "image tokens" — like words in a vocabulary.
    """
 
    def __init__(self, num_embeddings=512, embedding_dim=64):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
 
        # Codebook: a learned vocabulary of 512 latent vectors
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.codebook.weight.data.uniform_(
            -1.0 / num_embeddings, 1.0 / num_embeddings
        )
 
    def forward(self, z):
        """
        z: (batch, channels, height, width) — continuous latent from encoder
        Returns: quantized z, indices (the token IDs), combined codebook + commitment loss
        """
        # Reshape: (B, C, H, W) → (B*H*W, C)
        b, c, h, w = z.shape
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)
 
        # Find nearest codebook entry for each vector
        # Distance: ||z - e||^2 = ||z||^2 + ||e||^2 - 2*z*e^T
        distances = (
            z_flat.pow(2).sum(dim=1, keepdim=True)
            + self.codebook.weight.pow(2).sum(dim=1)
            - 2 * z_flat @ self.codebook.weight.t()
        )
        indices = distances.argmin(dim=1)  # closest codebook entry
 
        # Look up the quantized vectors
        z_quantized = self.codebook(indices).view(b, h, w, c).permute(0, 3, 1, 2)
 
        # Losses
        # Commitment loss: pulls the encoder output toward its (frozen) codebook entry
        commitment_loss = F.mse_loss(z, z_quantized.detach())
        # Codebook loss: pulls the codebook entry toward the (frozen) encoder output
        codebook_loss = F.mse_loss(z.detach(), z_quantized)
 
        # Straight-through estimator: the forward pass uses z_quantized, while
        # gradients flow back to z as if quantization were the identity
        z_quantized = z + (z_quantized - z).detach()
 
        # In the VQ-VAE formulation, the 0.25 weight (beta) scales the commitment term
        return z_quantized, indices.view(b, h, w), codebook_loss + 0.25 * commitment_loss
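
A tiny worked example may help make the lookup and the straight-through trick concrete. It assumes PyTorch and a made-up 3-entry codebook, and uses `torch.cdist` for the distances in place of the expanded formula in the class above.

```python
import torch

# A hypothetical codebook with 3 entries of dimension 2.
codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]])

# Three encoder outputs to quantize.
z = torch.tensor([[0.1, -0.2], [0.9, 1.2], [3.5, 4.1]], requires_grad=True)

# Nearest codebook entry by Euclidean distance.
distances = torch.cdist(z.detach(), codebook)  # (3 vectors, 3 codebook entries)
indices = distances.argmin(dim=1)              # the token IDs
z_q = codebook[indices]

# Straight-through estimator: the forward value equals z_q, but the backward
# pass treats quantization as the identity so gradients reach z.
z_st = z + (z_q - z).detach()
z_st.sum().backward()

print(indices.tolist())  # [0, 1, 2]
print(z.grad)            # all ones
```

Each vector snaps to its closest codebook row, and the gradient with respect to `z` is all ones even though `argmin` itself is not differentiable.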

Auto-regressive generation with a transformer:

Once you have the VQ-VAE, image generation becomes a sequence prediction problem:

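# Sketch: auto-regressive image generation
import torch
 
 
@torch.no_grad()
def generate_image(transformer, vq_vae_decoder, text_tokens, grid_size=32):
    """
    Generate an image token-by-token using a transformer,
    then decode with the VQ-VAE decoder.
 
    Assumed interfaces (illustrative, not a specific library's API):
      text_tokens:          (1, text_len) tensor of prompt token IDs
      transformer(tokens):  (1, seq_len) IDs -> (1, seq_len, vocab_size) logits,
                            with text and image tokens sharing one vocabulary
      vq_vae_decoder(grid): (1, grid_size, grid_size) IDs -> image tensor
    """
    tokens = text_tokens
 
    # Generate image tokens one by one
    for _ in range(grid_size * grid_size):
        # Predict the next token given the text and all previous image tokens
        logits = transformer(tokens)[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
        tokens = torch.cat([tokens, next_token], dim=1)
 
    # Drop the text prefix and decode the image tokens back to pixels
    image_tokens = tokens[:, text_tokens.shape[1]:]
    token_grid = image_tokens.reshape(1, grid_size, grid_size)
    return vq_vae_decoder(token_grid)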

Notable auto-regressive image models:

| Model | Approach | Key Innovation |
| --- | --- | --- |
| PixelCNN/PixelRNN | Predict pixels directly | First auto-regressive image models |
| VQ-VAE-2 | Hierarchical VQ-VAE + PixelCNN | Multi-scale token grids |
| DALL-E 1 | Discrete VAE (VQ-VAE-style) + GPT-style transformer | First large-scale text-to-image |
| Parti (Google) | ViT-VQGAN + large transformer | Scaled auto-regressive T2I |
| LlamaGen | Llama architecture for image tokens | LLM architecture for image generation |

Auto-regressive strengths and weaknesses:

| Strength | Weakness |
| --- | --- |
| Unified framework with text (same architecture) | Slow generation (sequential token-by-token) |
| Naturally handles multi-modal sequences | Token grid limits detail (lossy compression) |
| Well-understood scaling laws from LLMs | Quadratic attention cost in long sequences |
| Easy to add conditioning (text, class labels) | Lower quality than diffusion at same compute |

Diffusion Models

Diffusion models are the current state of the art. They work by learning to reverse a gradual noising process — starting from pure noise and iteratively removing noise until a clean image emerges.

The forward process (adding noise):

Start with a clean image and gradually add Gaussian noise over T steps until you have pure noise.

Clean image → Slightly noisy → Noisier → ... → Pure Gaussian noise
    x_0          x_1              x_2             x_T

The reverse process (removing noise):

Learn a neural network that predicts and removes the noise at each step, going from pure noise back to a clean image.

Pure noise → Slightly less noisy → ... → Clean image
    x_T          x_{T-1}                    x_0

Why this works: Adding noise is easy (just add random values). The neural network only has to learn to remove a little noise at a time, which is a much simpler task than generating an image from scratch in one shot.

import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
class SimpleDiffusion:
    """
    A minimal diffusion model implementation.
    Learns to denoise images step by step.
    """
 
    def __init__(self, num_timesteps=1000):
        self.T = num_timesteps
 
        # Noise schedule: beta controls how much noise is added at each step
        # Linear schedule from 0.0001 to 0.02
        self.betas = torch.linspace(1e-4, 0.02, num_timesteps)
 
        # Precompute useful quantities
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
 
    def forward_process(self, x_0, t, noise=None):
        """
        Add noise to a clean image x_0 at timestep t.
        q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
 
        The "nice property": we can jump to any timestep directly
        without iterating through all previous steps.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
 
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
 
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise
        return x_t, noise
 
    def training_step(self, model, x_0):
        """
        One training step:
        1. Sample a random timestep
        2. Add noise to the image
        3. Predict the noise using the model
        4. Compute the loss (predicted noise vs actual noise)
        """
        batch_size = x_0.shape[0]
 
        # Random timestep for each image in the batch
        t = torch.randint(0, self.T, (batch_size,))
 
        # Add noise
        noise = torch.randn_like(x_0)
        x_t, _ = self.forward_process(x_0, t, noise)
 
        # Predict the noise
        predicted_noise = model(x_t, t)
 
        # Simple MSE loss: predicted noise vs actual noise
        loss = F.mse_loss(predicted_noise, noise)
        return loss
 
    @torch.no_grad()
    def sample(self, model, shape, device='cpu'):
        """
        Generate an image by starting from pure noise
        and iteratively denoising.
        """
        # Start from pure noise
        x = torch.randn(shape, device=device)
 
        # Reverse process: denoise step by step
        for t in reversed(range(self.T)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
 
            # Predict noise
            predicted_noise = model(x, t_batch)
 
            # Compute denoising step
            alpha = self.alphas[t]
            alpha_bar = self.alphas_cumprod[t]
            beta = self.betas[t]
 
            # Mean of the reverse distribution
            mean = (1 / torch.sqrt(alpha)) * (
                x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
            )
 
            # Add noise (except at the final step)
            if t > 0:
                noise = torch.randn_like(x)
                sigma = torch.sqrt(beta)
                x = mean + sigma * noise
            else:
                x = mean
 
        return x

The noise prediction network (simplified U-Net):

class TimeEmbedding(nn.Module):
    """Encode the timestep t as a vector so the network knows 'how noisy' the input is."""
 
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
 
    def forward(self, t):
        # Sinusoidal position encoding (same as transformers)
        half_dim = self.dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t.float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        return self.mlp(emb)
 
 
class SimpleUNet(nn.Module):
    """
    Simplified U-Net for noise prediction.
    Takes a noisy image and timestep, predicts the noise.
    """
 
    def __init__(self, in_channels=1, base_channels=64, time_dim=256):
        super().__init__()
        self.time_embed = TimeEmbedding(time_dim)
 
        # Encoder (downsampling)
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.GroupNorm(8, base_channels),
            nn.GELU(),
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
            nn.GroupNorm(8, base_channels * 2),
            nn.GELU(),
        )
 
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(base_channels * 2, base_channels * 2, 3, padding=1),
            nn.GroupNorm(8, base_channels * 2),
            nn.GELU(),
        )
 
        # Time conditioning: project time embedding to channel dimension
        self.time_proj = nn.Linear(time_dim, base_channels * 2)
 
        # Decoder (upsampling) with skip connections
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(base_channels * 4, base_channels, 2, stride=2),
            nn.GroupNorm(8, base_channels),
            nn.GELU(),
        )
        self.out = nn.Conv2d(base_channels * 2, in_channels, 1)
 
    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_embed(t)
 
        # Encoder
        h1 = self.down1(x)          # (B, 64, H, W)
        h2 = self.down2(h1)         # (B, 128, H/2, W/2)
 
        # Add time conditioning to bottleneck
        t_proj = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.bottleneck(h2) + t_proj  # (B, 128, H/2, W/2)
 
        # Decoder with skip connections
        h = torch.cat([h, h2], dim=1)   # (B, 256, H/2, W/2)
        h = self.up1(h)                  # (B, 64, H, W)
        h = torch.cat([h, h1], dim=1)   # (B, 128, H, W)
        return self.out(h)               # (B, 1, H, W) — predicted noise

Training the diffusion model:

# Train on MNIST
diffusion = SimpleDiffusion(num_timesteps=1000)
model = SimpleUNet(in_channels=1, base_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
 
for epoch in range(50):
    total_loss = 0
    for batch, _ in loader:
        optimizer.zero_grad()
        loss = diffusion.training_step(model, batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
 
    print(f"Epoch {epoch + 1}: loss = {total_loss / len(loader):.4f}")
 
# Generate images
samples = diffusion.sample(model, shape=(16, 1, 28, 28))
# samples contains 16 generated 28x28 images

Why diffusion models won:

| Advantage over GANs | Advantage over VAEs | Advantage over Auto-regressive |
| --- | --- | --- |
| Stable training (no adversarial dynamics) | Much sharper images (no blurriness) | Parallel generation (all pixels at once) |
| Better mode coverage (more diverse) | Higher quality reconstruction | Faster generation per image |
| Mathematically principled | Better latent space for manipulation | Better scaling with compute |

Summary: Generation Approaches

| Approach | How It Generates | Speed | Quality | Key Use Today |
| --- | --- | --- | --- | --- |
| VAE | Decode from latent space | Very fast | Moderate (blurry) | Compression component in diffusion models |
| GAN | Generator fools discriminator | Very fast | High (sharp) | Real-time style transfer, face editing |
| Auto-regressive | Token by token | Slow | High | Multi-modal models (text + image in one model) |
| Diffusion | Iterative denoising | Moderate | Very high | DALL-E 3, Stable Diffusion, Midjourney, Sora |

Part II: Text-to-Image (T2I)

Text-to-image is one of the most commercially significant applications of generative models. Let's look at how modern T2I systems work end-to-end.

Data Preparation

T2I models need millions of image-text pairs. The quality and diversity of this data determine the quality of the model.

Major datasets:

| Dataset | Size | Source | Notes |
| --- | --- | --- | --- |
| LAION-5B | 5.85 billion pairs | Common Crawl (web scraping) | Largest open dataset, used by Stable Diffusion |
| LAION-Aesthetics | ~600M pairs | Filtered LAION-5B | Subset with high aesthetic scores |
| COYO-700M | 700M pairs | Web scraping | Korean-origin, multilingual |
| DataComp | 1.4B pairs | Curated from Common Crawl | Focus on data quality over quantity |

Data pipeline:

1. Scrape image-text pairs from the web (alt text, captions)
        ↓
2. Filter for quality:
   - Remove images smaller than 256x256
   - Remove duplicates (perceptual hashing)
   - Remove NSFW content (CLIP-based classifier)
   - Remove watermarked images
   - Remove images with low aesthetic scores
        ↓
3. Clean captions:
   - Remove boilerplate ("Click here to enlarge")
   - Standardize formatting
   - Optionally re-caption with a vision-language model (BLIP-2, LLaVA)
        ↓
4. Compute CLIP embeddings for efficient filtering and retrieval
        ↓
5. Store as WebDataset shards for efficient training I/O
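
Steps 2 and 3 of the pipeline above can be sketched with the standard library alone. Everything here is illustrative: the `Sample` fields and the boilerplate phrase list are hypothetical, an exact SHA-256 hash stands in for perceptual hashing (it only catches byte-identical duplicates), "smaller than 256x256" is read as "shorter side under 256", and the NSFW, watermark, and aesthetic filters are omitted because they need trained models.

```python
import hashlib
import re
from dataclasses import dataclass


@dataclass
class Sample:
    image_bytes: bytes
    width: int
    height: int
    caption: str


# Hypothetical boilerplate phrases to strip from alt text.
BOILERPLATE = ("click here to enlarge", "image may be subject to copyright")


def clean_caption(caption: str) -> str:
    """Remove boilerplate phrases (case-insensitive) and collapse whitespace."""
    for phrase in BOILERPLATE:
        caption = re.sub(re.escape(phrase), "", caption, flags=re.IGNORECASE)
    return " ".join(caption.split())


def filter_samples(samples, min_side=256):
    """Drop small images, exact duplicates, and samples with empty captions."""
    seen = set()
    kept = []
    for s in samples:
        if min(s.width, s.height) < min_side:
            continue
        digest = hashlib.sha256(s.image_bytes).hexdigest()
        if digest in seen:
            continue
        seen.add(digest)
        caption = clean_caption(s.caption)
        if not caption:
            continue
        kept.append(Sample(s.image_bytes, s.width, s.height, caption))
    return kept


raw = [
    Sample(b"img-a", 512, 512, "A red bicycle. Click here to enlarge"),
    Sample(b"img-a", 512, 512, "Duplicate of the first image"),
    Sample(b"img-b", 128, 512, "Too small on one side"),
    Sample(b"img-c", 300, 400, "Click here to enlarge"),
    Sample(b"img-d", 1024, 768, "  A dog   on a beach "),
]
kept = filter_samples(raw)
print([s.caption for s in kept])
# ['A red bicycle.', 'A dog on a beach']
```

A production pipeline would run the same kind of checks in parallel over shards, but the order of operations (cheap size checks first, model-based filters last) is the part worth copying.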

Re-captioning with a vision model:

Web alt-text is often low quality ("IMG_2847.jpg", "photo", "banner image"). Modern pipelines use a vision-language model to generate better captions:

# Sketch: re-caption images with a VLM (the checkpoint is a multi-GB download)
from transformers import Blip2Processor, Blip2ForConditionalGeneration
 
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
vlm = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
 
def recaption_image(image):
    """Generate a detailed caption for a PIL image."""
    prompt = "Describe this image in detail:"
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    output = vlm.generate(**inputs, max_new_tokens=100)
    caption = processor.decode(output[0], skip_special_tokens=True)
    return caption.strip()
 
# Original alt-text: "photo"
# Re-captioned: "A golden retriever sitting on a wooden dock at sunset,
#                with a calm lake and mountains in the background"

Diffusion Architectures

Two main architectures are used for the denoising network:

U-Net Architecture (Stable Diffusion 1.x/2.x, DALL-E 2)

The U-Net is an encoder-decoder with skip connections. It was originally designed for medical image segmentation and adapted for diffusion.

Input (noisy latent)
    ↓
[Down Block 1] ──────────────────────────────→ [Up Block 3] (skip connection)
    ↓                                              ↑
[Down Block 2] ──────────────────────────────→ [Up Block 2] (skip connection)
    ↓                                              ↑
[Down Block 3] ──────────────────────────────→ [Up Block 1] (skip connection)
    ↓                                              ↑
              [Middle Block (attention)]

Each block contains:

  • Residual convolution layers
  • Self-attention layers (for global context)
  • Cross-attention layers (for text conditioning)
  • Time embedding injection

Cross-attention for text conditioning:

This is how the model "understands" the text prompt. The image features attend to the text embeddings:

class CrossAttention(nn.Module):
    """Image features attend to text embeddings."""
 
    def __init__(self, dim, context_dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
 
        self.to_q = nn.Linear(dim, dim)           # queries from image
        self.to_k = nn.Linear(context_dim, dim)    # keys from text
        self.to_v = nn.Linear(context_dim, dim)    # values from text
        self.to_out = nn.Linear(dim, dim)
 
    def forward(self, x, context):
        """
        x: image features (batch, seq_len, dim)
        context: text embeddings (batch, text_len, context_dim)
        """
        b, n, _ = x.shape
        h = self.num_heads
 
        q = self.to_q(x).view(b, n, h, self.head_dim).transpose(1, 2)
        k = self.to_k(context).view(b, -1, h, self.head_dim).transpose(1, 2)
        v = self.to_v(context).view(b, -1, h, self.head_dim).transpose(1, 2)
 
        # Attention: image queries attend to text keys
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
 
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
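
To see what the attention step produces, here is a shape-only sketch with random tensors, assuming PyTorch. The sizes are made up for illustration: 256 image positions, 77 text tokens (the CLIP context length is an assumption here), 8 heads of dimension 40.

```python
import torch

torch.manual_seed(0)
batch, heads, head_dim = 2, 8, 40
image_tokens, text_tokens = 16 * 16, 77  # assumed sizes for illustration

q = torch.randn(batch, heads, image_tokens, head_dim)  # from image features
k = torch.randn(batch, heads, text_tokens, head_dim)   # from text embeddings
v = torch.randn(batch, heads, text_tokens, head_dim)   # from text embeddings

# One row of weights per image position, one column per text token.
attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
attn = attn.softmax(dim=-1)
out = attn @ v

print(attn.shape)  # torch.Size([2, 8, 256, 77])
print(out.shape)   # torch.Size([2, 8, 256, 40])
```

Each image position gets a probability distribution over the text tokens, and its output is the matching weighted average of text values. The output keeps the image's sequence length, which is why the layer can sit inside the U-Net without changing spatial shapes.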

DiT Architecture (Stable Diffusion 3, FLUX)

Diffusion Transformer (DiT) replaces the U-Net with a pure transformer architecture. Instead of convolutions, the image is split into patches (like Vision Transformer) and processed with transformer blocks.

Noisy latent (e.g., 64x64x4)
    ↓
[Patchify] → sequence of patch tokens (e.g., 1024 tokens: a 32x32 grid of 2x2 patches)
    ↓
[+ positional embeddings]
    ↓
[DiT Block 1] ← timestep embedding + text embeddings
    ↓
[DiT Block 2] ← timestep embedding + text embeddings
    ↓
  ... (24-40 blocks)
    ↓
[DiT Block N]
    ↓
[Unpatchify] → predicted noise (64x64x4)
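
The patchify and unpatchify steps in the diagram can be sketched as pure reshapes, assuming PyTorch's channels-first layout (so the 64x64x4 latent is a `(B, 4, 64, 64)` tensor) and a patch size of 2. The function names are illustrative, and a real DiT also applies a learned linear projection to each patch, which is left out here.

```python
import torch


def patchify(latent, patch_size):
    """(B, C, H, W) -> (B, num_patches, patch_size * patch_size * C)."""
    b, c, h, w = latent.shape
    p = patch_size
    x = latent.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, H/p, W/p, p, p, C)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


def unpatchify(tokens, patch_size, channels, height, width):
    """Inverse of patchify: (B, num_patches, p * p * C) -> (B, C, H, W)."""
    b = tokens.shape[0]
    p = patch_size
    x = tokens.reshape(b, height // p, width // p, p, p, channels)
    x = x.permute(0, 5, 1, 3, 2, 4)  # (B, C, H/p, p, W/p, p)
    return x.reshape(b, channels, height, width)


latent = torch.randn(1, 4, 64, 64)
tokens = patchify(latent, patch_size=2)
restored = unpatchify(tokens, patch_size=2, channels=4, height=64, width=64)

print(tokens.shape)                    # torch.Size([1, 1024, 16])
print(torch.equal(restored, latent))   # True
```

With 2x2 patches, the 64x64 latent becomes 1024 tokens of 16 values each, and the round trip recovers the original tensor exactly.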

DiT block with adaptive layer norm (adaLN-Zero):

class DiTBlock(nn.Module):
    """
    A single DiT block. Uses adaptive layer norm to inject
    timestep and text conditioning.
    """
 
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )
 
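        # adaLN-Zero: regress scale, shift, and gate values from the conditioning
        # 6 modulation vectors per block: gamma1, beta1, alpha1, gamma2, beta2, alpha2
        self.adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )
        # The "Zero" in adaLN-Zero: start with all modulation outputs at zero, so the
        # gates alpha1/alpha2 are zero and each block initially acts as the identity
        nn.init.zeros_(self.adaln[-1].weight)
        nn.init.zeros_(self.adaln[-1].bias)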
 
    def forward(self, x, conditioning):
        """
        x: patch tokens (batch, seq_len, dim)
        conditioning: timestep + text embedding (batch, dim)
        """
        # Get adaptive layer norm parameters
        adaln_params = self.adaln(conditioning).unsqueeze(1)
        gamma1, beta1, alpha1, gamma2, beta2, alpha2 = adaln_params.chunk(6, dim=-1)
 
        # Self-attention with adaLN
        h = self.norm1(x) * (1 + gamma1) + beta1
        h, _ = self.attn(h, h, h)
        x = x + alpha1 * h
 
        # MLP with adaLN
        h = self.norm2(x) * (1 + gamma2) + beta2
        h = self.mlp(h)
        x = x + alpha2 * h
 
        return x

U-Net vs DiT:

| Aspect | U-Net | DiT |
| --- | --- | --- |
| Architecture | CNN + attention | Pure transformer |
| Scaling | Hard to scale uniformly | Scales like LLMs (more layers, more heads) |
| Inductive bias | Strong spatial bias from convolutions | Minimal bias; learned from data |
| Used by | Stable Diffusion 1.x/2.x, DALL-E 2 | Stable Diffusion 3, FLUX, Sora |
| Compute | More efficient at lower scales | More efficient at larger scales |
| Trend | Legacy | Current standard |

Diffusion Training

The Forward Process (Adding Noise)

Given a clean image x_0, the forward process produces a noisy version x_t by adding Gaussian noise scaled by the timestep t:

q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)

In plain English: x_t is a weighted mix of the original image and random noise, where the weight depends on the timestep.

At t=0, x_t is almost entirely the original image. At t=T, x_t is almost entirely noise.

# The forward process in one line:
x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

Noise schedules control how fast noise is added:

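| Schedule | Formula | Characteristics |
|---|---|---|
| Linear | beta_t = beta_min + t/T * (beta_max - beta_min) | Simple, used in original DDPM |
| Cosine | alpha_bar_t = f(t) / f(0), where f(t) = cos((t/T + s) / (1+s) * pi/2)^2 | Smoother; destroys information more slowly, which helped most at low resolutions (32x32, 64x64) |
| Scaled linear | sqrt(beta_t) is linear in t (adjusted for latent space diffusion) | Used in Stable Diffusion |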
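
To make the schedules concrete, here is a minimal sketch that computes alpha_bar_t for the linear and cosine schedules and applies the forward process to a toy "image". The function names are illustrative, and the defaults (beta_min=1e-4, beta_max=0.02, s=0.008) are assumed from the original DDPM and cosine-schedule papers rather than taken from this post.

```python
import math


def linear_alpha_bar(T, beta_min=1e-4, beta_max=0.02):
    """alpha_bar_t for a linear beta schedule: running product of (1 - beta_t)."""
    alpha_bar, running = [], 1.0
    for t in range(T):
        beta_t = beta_min + (t / (T - 1)) * (beta_max - beta_min)
        running *= 1.0 - beta_t
        alpha_bar.append(running)
    return alpha_bar


def cosine_alpha_bar(T, s=0.008):
    """alpha_bar_t for the cosine schedule, normalized so alpha_bar at t=0 is 1."""
    def f(t):
        return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
    return [f(t) / f(0) for t in range(1, T + 1)]


def add_noise(x0, noise, alpha_bar_t):
    """Forward process: weighted mix of the clean values and the noise."""
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * n for x, n in zip(x0, noise)]


linear = linear_alpha_bar(1000)
cosine = cosine_alpha_bar(1000)
# Halfway through, the cosine schedule has kept far more of the signal.
print(f"alpha_bar at t=500: linear={linear[499]:.3f}, cosine={cosine[499]:.3f}")
```

Both schedules start near alpha_bar = 1 (almost no noise) and end near 0 (almost pure noise); they differ in how quickly they get there.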

The Backward Process (Learning to Denoise)

The model learns to predict the noise that was added, so it can be subtracted:

Training objective:
  L = E[||noise - model(x_t, t, text)||^2]

In words: The loss is the mean squared error between
the actual noise that was added and the noise the model predicts.

Training algorithm (simplified):

Repeat:
  1. Sample a clean image x_0 from the dataset
  2. Sample a random timestep t ~ Uniform(1, T)
  3. Sample random noise epsilon ~ N(0, I)
  4. Compute noisy image: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
  5. Predict noise: epsilon_pred = model(x_t, t, text_embedding)
  6. Compute loss: L = ||epsilon - epsilon_pred||^2
  7. Update model weights via gradient descent

What the model actually predicts — there are three equivalent formulations:

| Prediction Target | What the Model Outputs | Notes |
|---|---|---|
| Noise (epsilon) | The noise added at step t | Most common (DDPM, Stable Diffusion 1.x) |
| Clean image (x_0) | The denoised image directly | Used in some formulations |
| Velocity (v) | v = sqrt(alpha_bar_t) * epsilon - sqrt(1 - alpha_bar_t) * x_0 | Better numerical stability, used in SD 2.x+ |
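
The three targets are equivalent because, given x_t and alpha_bar_t, any one of them determines the other two. The sketch below shows the conversions on scalars; the function names are made up for illustration, and real code would apply the same arithmetic elementwise to tensors.

```python
import math


def coefficients(alpha_bar_t):
    """Signal and noise weights used by the forward process."""
    return math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)


def to_velocity(x0, eps, alpha_bar_t):
    """v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0."""
    a, b = coefficients(alpha_bar_t)
    return a * eps - b * x0


def from_velocity(x_t, v, alpha_bar_t):
    """Recover (x0, eps) from a velocity prediction. Uses a^2 + b^2 = 1."""
    a, b = coefficients(alpha_bar_t)
    return a * x_t - b * v, b * x_t + a * v


def x0_from_eps(x_t, eps, alpha_bar_t):
    """Recover x0 from a noise prediction by inverting the forward process."""
    a, b = coefficients(alpha_bar_t)
    return (x_t - b * eps) / a


# Round trip on a single value
x0, eps, alpha_bar_t = 0.7, -1.2, 0.36
a, b = coefficients(alpha_bar_t)
x_t = a * x0 + b * eps
v = to_velocity(x0, eps, alpha_bar_t)
print(from_velocity(x_t, v, alpha_bar_t))   # approximately (0.7, -1.2)
print(x0_from_eps(x_t, eps, alpha_bar_t))   # approximately 0.7
```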

Latent diffusion (what Stable Diffusion actually does):

Running diffusion in pixel space (256x256x3) is very expensive. Latent diffusion first compresses the image with a VAE, then runs diffusion in the latent space:

Pixel space diffusion:  256x256x3 = 196,608 dimensions
Latent space diffusion: 32x32x4   = 4,096 dimensions  (48x fewer values to denoise)

Image → [VAE Encoder] → Latent → [Diffusion in latent space] → Latent → [VAE Decoder] → Image

# Pseudocode: Latent Diffusion Training
# Assumes a diffusers-style VAE, a global T (number of timesteps), and
# tokenize / forward_process helpers defined elsewhere.
import torch
import torch.nn.functional as F

def train_step(vae, diffusion_model, text_encoder, image, caption):
    # 1. Encode image to latent space (the pretrained VAE stays frozen)
    with torch.no_grad():
        latent = vae.encode(image).latent_dist.sample()
        latent = latent * 0.18215  # latent scaling factor used by SD 1.x

    # 2. Encode text caption (the text encoder is frozen too)
    with torch.no_grad():
        text_emb = text_encoder(tokenize(caption))

    # 3. Standard diffusion training in latent space:
    #    one random timestep per sample in the batch
    batch_size = latent.shape[0]
    t = torch.randint(0, T, (batch_size,), device=latent.device)
    noise = torch.randn_like(latent)
    noisy_latent = forward_process(latent, t, noise)

    # 4. Predict noise, conditioned on text
    noise_pred = diffusion_model(noisy_latent, t, text_emb)

    # 5. Loss: MSE between predicted and actual noise
    loss = F.mse_loss(noise_pred, noise)
    return loss

Diffusion Sampling

At inference time, we start from pure noise and iteratively denoise. The sampling strategy significantly affects speed and quality.

DDPM Sampling (Original)

The original approach: denoise one step at a time through all T timesteps (typically T=1000).

# DDPM sampling: 1000 steps (slow but high quality)
x = torch.randn(1, 4, 64, 64)  # pure noise in latent space
 
for t in reversed(range(1000)):
    noise_pred = model(x, t, text_embedding)
    x = denoise_step(x, noise_pred, t)  # one denoising step
 
image = vae.decode(x / 0.18215)  # undo the latent scaling, then decode to pixels

Problem: 1000 neural network forward passes per image is very slow.

DDIM Sampling (Faster)

DDIM (Denoising Diffusion Implicit Models) allows you to skip steps — denoise in 20-50 steps instead of 1000.

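# DDIM sampling: 50 steps instead of 1000 (20x fewer forward passes)
timesteps = list(range(999, -1, -20))  # [999, 979, ..., 19]: 50 evenly spaced steps

x = torch.randn(1, 4, 64, 64)
for i, t in enumerate(timesteps):
    # Next (less noisy) timestep; -1 marks the final step to the clean latent
    t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else -1
    noise_pred = model(x, t, text_embedding)
    x = ddim_step(x, noise_pred, t, t_prev)  # deterministic step

image = vae.decode(x / 0.18215)  # undo the latent scaling, then decode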

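DDIM is deterministic when its stochasticity parameter eta is 0 (the usual setting): given the same starting noise and prompt, you always get the same image. This enables useful features like interpolating between images by interpolating their starting noise.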
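
For intuition about what a deterministic step does, here is a minimal sketch of the eta = 0 DDIM update on a flat list of floats. It takes alpha_bar values directly instead of timestep indices, so `ddim_update` is a hypothetical stand-in with a different signature from the `ddim_step` helper used above.

```python
import math


def ddim_update(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM step (eta = 0).

    1. Use the predicted noise to estimate the clean sample x0.
    2. Re-noise that estimate to the (lower) noise level of the previous step,
       reusing the same predicted noise instead of sampling fresh noise.
    """
    a_t, b_t = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    a_p, b_p = math.sqrt(alpha_bar_prev), math.sqrt(1.0 - alpha_bar_prev)
    out = []
    for x, e in zip(x_t, eps_pred):
        x0_pred = (x - b_t * e) / a_t
        out.append(a_p * x0_pred + b_p * e)
    return out


# Sanity check with an "oracle" that knows the true noise:
# stepping all the way to alpha_bar = 1 should recover x0.
x0 = [0.5, -1.0]
eps = [0.3, 0.8]
alpha_bar_t = 0.25
x_t = [math.sqrt(alpha_bar_t) * x + math.sqrt(1 - alpha_bar_t) * e
       for x, e in zip(x0, eps)]
print(ddim_update(x_t, eps, alpha_bar_t, alpha_bar_prev=1.0))
```

No random number is drawn inside the step, which is why the same starting noise always leads to the same result.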

Classifier-Free Guidance (CFG)

The most important trick in modern T2I. CFG amplifies the effect of the text prompt by comparing conditioned and unconditioned predictions:

noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

  • guidance_scale = 1.0: No guidance (reduces to the plain conditional prediction)
  • guidance_scale = 7.5: Standard (good balance of quality and diversity)
  • guidance_scale = 15.0+: Very strong guidance (high prompt adherence, less diversity)

def guided_denoise_step(model, x_t, t, text_embedding, empty_text_embedding,
                        guidance_scale=7.5):
    """
    Classifier-free guidance: run the model twice (with and without text),
    then amplify the difference.

    empty_text_embedding is the encoder output for an empty prompt ("").
    """
    # Unconditional prediction (empty text)
    noise_uncond = model(x_t, t, empty_text_embedding)

    # Conditional prediction (with text)
    noise_cond = model(x_t, t, text_embedding)

    # Guided prediction: move away from the unconditional output,
    # in the direction the prompt pulls
    noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
    return noise_pred

During training, CFG requires randomly dropping the text conditioning (replacing it with an empty string) some percentage of the time (typically 10-20%). This teaches the model to generate both conditionally and unconditionally.
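
The conditioning dropout described above is only a few lines in a data pipeline. This is a minimal sketch; `drop_caption` and `p_uncond` are illustrative names, and the 10% default is just the low end of the range quoted above.

```python
import random


def drop_caption(caption, p_uncond=0.1, rng=random):
    """Return "" with probability p_uncond, otherwise the caption unchanged.

    Training on the empty caption teaches the model the unconditional
    prediction that classifier-free guidance needs at inference time.
    """
    return "" if rng.random() < p_uncond else caption


# Apply per training example; a seeded generator makes the run repeatable
rng = random.Random(0)
captions = ["a cat wearing a top hat"] * 10_000
dropped = sum(drop_caption(c, p_uncond=0.1, rng=rng) == "" for c in captions)
print(f"{dropped / len(captions):.1%} of captions were dropped")
```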

Modern samplers (DPM-Solver, UniPC, Euler, etc.) can produce good results in 20-30 steps by using higher-order ODE solvers.


Evaluation

How do we measure whether a T2I model is actually good? There are several complementary metrics.

| Metric | What It Measures | How It Works | Good Score |
|---|---|---|---|
| FID (Frechet Inception Distance) | Image quality + diversity | Compare statistics of generated vs real images in Inception feature space | Lower is better (good models: 5-15) |
| IS (Inception Score) | Image quality + diversity | How confidently Inception classifies generated images, and how diverse those classifications are | Higher is better (good models: 50-200+) |
| CLIP Score | Image-text alignment | Cosine similarity between CLIP embeddings of the image and text prompt | Higher is better (raw cosine lies in [-1, 1]; even well-matched pairs score well below 1) |
| Aesthetic Score | Visual appeal | Trained predictor of human aesthetic preference | Higher is better (1-10 scale) |
| Human Evaluation | Overall quality | Humans compare and rate generated images | Gold standard but expensive |

# Computing CLIP score
import torch
from transformers import CLIPProcessor, CLIPModel

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()

def clip_score(image, text):
    """Measure how well the image (a PIL image) matches the text description."""
    # CLIP's text encoder has a 77-token context, so truncate long prompts
    inputs = clip_processor(
        text=[text], images=[image], return_tensors="pt",
        padding=True, truncation=True,
    )
    with torch.no_grad():  # inference only, no gradients needed
        outputs = clip_model(**inputs)

    # Cosine similarity between image and text embeddings
    image_emb = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
    text_emb = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)

    score = (image_emb @ text_emb.T).item()
    return score

# Example
score = clip_score(generated_image, "a cat wearing a top hat in a garden")
print(f"CLIP score: {score:.3f}")

# Computing FID (using pytorch-fid)
# pip install pytorch-fid

# Generate 10,000 images and save to a directory
# Then compare against a directory of real images:
# python -m pytorch_fid path/to/real_images path/to/generated_images
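
To see what "compare statistics" means, it helps to look at the Frechet distance in one dimension. Real FID fits multivariate Gaussians to 2048-dimensional Inception features and needs a matrix square root; the sketch below is a deliberately simplified 1-D version (the name `fid_1d` is made up) where the formula collapses to a squared mean difference plus a squared standard-deviation difference.

```python
import statistics


def fid_1d(real, generated):
    """Frechet distance between two 1-D Gaussians fitted to the samples.

    General form: |mu_r - mu_g|^2 + Tr(S_r + S_g - 2 * sqrt(S_r * S_g)).
    In 1-D the trace term simplifies to (sd_r - sd_g)^2.
    """
    mu_r, mu_g = statistics.fmean(real), statistics.fmean(generated)
    sd_r, sd_g = statistics.pstdev(real), statistics.pstdev(generated)
    return (mu_r - mu_g) ** 2 + (sd_r - sd_g) ** 2


real = [0.0, 1.0, 2.0]
print(fid_1d(real, [0.0, 1.0, 2.0]))    # identical statistics
print(fid_1d(real, [2.0, 3.0, 4.0]))    # same spread, shifted mean (quality gap)
print(fid_1d(real, [1.0, 1.0, 1.0]))    # same mean, no spread (mode collapse)
```

The last case shows why FID captures diversity as well as quality: a generator that always outputs the "average" image matches the mean but is penalized for its missing variance.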

Evaluation pitfalls:

| Pitfall | Why It's a Problem |
|---|---|
| FID is sensitive to sample size | Need at least 10K-50K images for stable FID |
| IS doesn't measure text alignment | High IS just means sharp, diverse images |
| CLIP score can be gamed | Model could overfit to CLIP's biases |
| Single metrics are insufficient | Always use multiple metrics + human evaluation |

Part III: Text-to-Video (T2V)

Text-to-video extends diffusion to the temporal dimension. Instead of generating a single image, we generate a sequence of frames that are temporally coherent — objects move smoothly, lighting changes gradually, and the scene makes physical sense across time.

Why Video Is Harder Than Images

| Challenge | Why It's Hard |
|---|---|
| Temporal coherence | Each frame must be consistent with the previous frame; characters shouldn't teleport or morph |
| Data volume | A 4-second video at 24fps is 96 frames. At 256x256 RGB, that's 18.9M values per clip (vs 196K for a single image) |
| Motion understanding | The model must learn physics, object permanence, and natural motion from data |
| Compute | Training and inference costs scale linearly (or worse) with video length |
| Evaluation | Harder to measure: temporal quality, motion naturalness, scene consistency |

Latent Diffusion Modeling (LDM) and Compression Networks

Just as Stable Diffusion runs diffusion in a compressed latent space for images, video diffusion models use compression networks to reduce video to a manageable latent representation.

Image VAE vs Video VAE (3D VAE):

Image VAE:
  Image (256x256x3) → [2D Encoder] → Latent (32x32x4)
  Compression: spatial only (8x downsampling)

Video VAE (3D VAE):
  Video (T x 256x256x3) → [3D Encoder] → Latent (T/4 x 32x32x4)
  Compression: spatial (8x) AND temporal (4x)

The temporal compression is crucial: a 96-frame video becomes a 24-frame latent sequence, making diffusion tractable.
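
The shape arithmetic is easy to check by hand. This small sketch assumes the 8x spatial, 4x temporal, 4-channel configuration from the diagram above; `video_latent_shape` and `num_values` are illustrative helpers, not part of any library.

```python
def video_latent_shape(frames, height, width, spatial=8, temporal=4, channels=4):
    """Latent shape (T, H, W, C) after spatial and temporal downsampling."""
    return (frames // temporal, height // spatial, width // spatial, channels)


def num_values(shape):
    """Total number of scalars in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total


pixels = (96, 256, 256, 3)                  # 4 seconds at 24 fps, RGB
latent = video_latent_shape(96, 256, 256)   # (24, 32, 32, 4)
ratio = num_values(pixels) / num_values(latent)
print(latent, f"{ratio:.0f}x fewer values than pixel space")
```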

# Pseudocode: 3D VAE for video compression
import torch.nn as nn

class Video3DVAE(nn.Module):
    """
    Compress video spatially AND temporally.
    Uses 3D convolutions to handle the time dimension.

    Simplified sketch: a real VAE encoder outputs a mean and log-variance
    per latent channel and samples from them; this one outputs the latent
    directly. Expects T divisible by 4 and H, W divisible by 8.
    """

    def __init__(self, spatial_downsample=8, temporal_downsample=4):
        super().__init__()
        # The layer stack below hard-codes 8x spatial / 4x temporal compression
        if (spatial_downsample, temporal_downsample) != (8, 4):
            raise ValueError("this sketch only implements 8x spatial / 4x temporal")

        # 3D convolutions: kernel/stride/padding are (time, height, width)
        self.encoder = nn.Sequential(
            # Spatial + temporal downsampling
            nn.Conv3d(3, 64, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),     # H/2
            nn.SiLU(),
            nn.Conv3d(64, 128, kernel_size=(3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),   # T/2, H/4
            nn.SiLU(),
            nn.Conv3d(128, 256, kernel_size=(3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),  # T/4, H/8
            nn.SiLU(),
            # Project to 4 latent channels without downsampling further
            # (a fourth stride-2 layer would give 16x spatial, not 8x)
            nn.Conv3d(256, 4, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)),
            # Output: (B, 4, T/4, H/8, W/8)
        )

        self.decoder = nn.Sequential(
            # Reverse the compression. Stride-2 transposed convs use kernel 4
            # with padding 1 so each one exactly doubles that dimension.
            nn.ConvTranspose3d(4, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)),
            nn.SiLU(),
            nn.ConvTranspose3d(256, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.SiLU(),
            nn.ConvTranspose3d(128, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.SiLU(),
            nn.ConvTranspose3d(64, 3, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
        )

    def encode(self, video):
        """video: (B, C, T, H, W) → latent: (B, 4, T/4, H/8, W/8)"""
        return self.encoder(video)

    def decode(self, latent):
        """latent: (B, 4, T/4, H/8, W/8) → video: (B, C, T, H, W)"""
        return self.decoder(latent)

Training the video VAE:

  • Reconstruction loss (MSE or perceptual loss between input and decoded video)
  • KL divergence loss (keep latent close to Gaussian)
  • Temporal consistency loss (ensure decoded frames are smooth)
  • Adversarial loss (optional — use a discriminator for sharper outputs)

Data Preparation for T2V

Video data is far messier and more expensive to process than image data.

Data pipeline:

1. Source collection
   - Stock video libraries (Shutterstock, Pexels)
   - Web-scraped video-text pairs
   - Internal datasets
        ↓
2. Filtering
   - Remove static videos (no motion)
   - Remove slide shows and text-heavy videos
   - Remove videos with excessive camera shake
   - Scene-cut detection: split long videos into single-scene clips
   - Motion scoring: compute optical flow, filter low/excessive motion
   - Aesthetic filtering: score visual quality per frame
        ↓
3. Standardization
   - Resize to target resolution (e.g., 512x512)
   - Normalize frame rate (e.g., 24 fps)
   - Fixed clip length (e.g., 4 seconds = 96 frames)
   - Center crop or pad to square aspect ratio
        ↓
4. Captioning
   - Use a video captioning model (e.g., InternVideo, LLaVA-Video)
   - Generate both short captions ("a cat jumping on a table")
     and detailed descriptions ("a tabby cat with orange fur leaps from
     the floor onto a wooden kitchen table, knocking over a glass of water")
        ↓
5. Video latent caching
   - Pre-encode all videos through the 3D VAE
   - Store latent representations to disk
   - This avoids re-encoding during training (major speedup)
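
As an illustration of the motion-scoring filter in step 2, here is a minimal sketch that uses mean absolute frame difference as a cheap stand-in for optical flow. That substitution is deliberate: real pipelines compute flow with a dedicated library, and the thresholds below are arbitrary placeholders that would need tuning on actual data.

```python
def motion_score(frames):
    """Mean absolute per-pixel change between consecutive frames.

    frames: list of frames, each a flat list of grayscale values in [0, 1].
    """
    if len(frames) < 2:
        return 0.0
    total, count = 0.0, 0
    for prev, cur in zip(frames, frames[1:]):
        total += sum(abs(a - b) for a, b in zip(prev, cur))
        count += len(cur)
    return total / count


def keep_clip(frames, min_motion=0.01, max_motion=0.5):
    """Drop near-static clips and clips with excessive change (cuts, flicker)."""
    return min_motion <= motion_score(frames) <= max_motion


static = [[0.5] * 4 for _ in range(3)]            # nothing moves
gentle = [[0.0] * 4, [0.1] * 4, [0.2] * 4]        # gradual change
flicker = [[0.0] * 4, [1.0] * 4, [0.0] * 4]       # every pixel flips each frame
for name, clip in [("static", static), ("gentle", gentle), ("flicker", flicker)]:
    print(name, round(motion_score(clip), 3), keep_clip(clip))
```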

Video latent caching:

# Pseudocode: pre-compute and cache video latents
import os
from glob import glob

import torch
from torch.utils.data import Dataset

def cache_video_latents(dataset, vae, output_dir):
    """
    Pre-encode all training videos through the VAE and cache to disk.
    This is a one-time cost that saves enormous compute during training.
    """
    os.makedirs(output_dir, exist_ok=True)
    vae.eval()

    for idx, (video, caption) in enumerate(dataset):
        with torch.no_grad():
            # video: (C, T, H, W) → latent: (4, T/4, H/8, W/8)
            latent = vae.encode(video.unsqueeze(0)).squeeze(0)

        # Save latent (moved to CPU so it loads on any device) and caption
        torch.save({
            'latent': latent.cpu(),
            'caption': caption,
        }, f"{output_dir}/sample_{idx:08d}.pt")

    print(f"Cached {len(dataset)} video latents")


# During training, load latents instead of raw video:
class CachedVideoDataset(Dataset):
    def __init__(self, cache_dir):
        self.files = sorted(glob(f"{cache_dir}/*.pt"))

    def __len__(self):
        # Required so a DataLoader can sample indices
        return len(self.files)

    def __getitem__(self, idx):
        data = torch.load(self.files[idx], map_location="cpu")
        return data['latent'], data['caption']

DiT Architecture for Videos

Extending DiT to video means handling an additional time dimension. The key architecture choices:

Patch embedding for video:

Video latent: (T, H, W, C) → e.g., (24, 32, 32, 4)
    ↓
Patchify: split into (t, h, w) patches, e.g., (2, 2, 2)
    ↓
Tokens: (T/2 * H/2 * W/2) = 12 * 16 * 16 = 3,072 tokens
    ↓
Linear projection: each 3D patch → token vector of dimension d

Attention strategies for video:

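Full 3D attention is quadratic in sequence length. The 3,072 tokens in this small example are manageable, but realistic resolutions and durations produce tens or hundreds of thousands of tokens, where the cost becomes prohibitive. Many video DiTs therefore use factored attention: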

Option A: Full 3D attention (prohibitive at scale)
  Every token attends to every other token
  Cost: O((T*H*W)^2) — too expensive for long videos

Option B: Factored attention (practical)
  Alternate between:
  1. Spatial attention: tokens within the same frame attend to each other
  2. Temporal attention: tokens at the same spatial position attend across frames

  Cost: O(T * (H*W)^2 + H*W * T^2) — much cheaper
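
Plugging the running example into both cost formulas shows the size of the gap. This is a back-of-the-envelope sketch that counts query-key pairs only (it ignores heads, the feature dimension, and the MLP); the function names are illustrative.

```python
def token_grid(latent_thw=(24, 32, 32), patch_thw=(2, 2, 2)):
    """Number of patch tokens along time, and per frame, after patchifying."""
    t, h, w = (size // patch for size, patch in zip(latent_thw, patch_thw))
    return t, h * w


def attention_pairs(latent_thw=(24, 32, 32), patch_thw=(2, 2, 2)):
    """Query-key pairs scored by full 3D vs factored attention."""
    frames, per_frame = token_grid(latent_thw, patch_thw)
    tokens = frames * per_frame
    full = tokens ** 2
    # Spatial: each frame attends within itself. Temporal: each position across frames.
    factored = frames * per_frame ** 2 + per_frame * frames ** 2
    return {"tokens": tokens, "full": full, "factored": factored}


small = attention_pairs()
print(small, f"ratio: {small['full'] / small['factored']:.1f}x")

# Same patch size on a larger latent: the gap widens quickly
large = attention_pairs(latent_thw=(48, 128, 128))
print(large, f"ratio: {large['full'] / large['factored']:.1f}x")
```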
class VideoTransformerBlock(nn.Module):
    """
    DiT block for video: alternates spatial and temporal attention.
    """
 
    def __init__(self, dim, num_heads):
        super().__init__()
 
        # Spatial self-attention (within each frame)
        self.spatial_norm = nn.LayerNorm(dim)
        self.spatial_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
 
        # Temporal self-attention (across frames at same spatial position)
        self.temporal_norm = nn.LayerNorm(dim)
        self.temporal_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
 
        # Cross-attention for text conditioning
        self.cross_norm = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
 
        # Feed-forward
        self.ff_norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
 
    def forward(self, x, text_emb, num_frames, h, w):
        """
        x: (batch, num_frames * h * w, dim) — flattened video tokens
        text_emb: (batch, text_len, dim) — text conditioning
        """
        b = x.shape[0]
        thw = num_frames * h * w
 
        # 1. Spatial attention: reshape so each frame is a separate sequence
        # (batch, T*H*W, dim) → (batch*T, H*W, dim)
        x_spatial = x.view(b * num_frames, h * w, -1)
        h_spatial = self.spatial_norm(x_spatial)
        h_spatial, _ = self.spatial_attn(h_spatial, h_spatial, h_spatial)
        x = x + h_spatial.view(b, thw, -1)
 
        # 2. Temporal attention: reshape so each spatial position is a sequence across time
        # (batch, T*H*W, dim) → (batch*H*W, T, dim)
        x_temp = x.view(b, num_frames, h * w, -1).permute(0, 2, 1, 3).reshape(b * h * w, num_frames, -1)
        h_temp = self.temporal_norm(x_temp)
        h_temp, _ = self.temporal_attn(h_temp, h_temp, h_temp)
        x = x + h_temp.reshape(b, h * w, num_frames, -1).permute(0, 2, 1, 3).reshape(b, thw, -1)
 
        # 3. Cross-attention with text
        h_cross = self.cross_norm(x)
        h_cross, _ = self.cross_attn(h_cross, text_emb, text_emb)
        x = x + h_cross
 
        # 4. Feed-forward
        x = x + self.ff(self.ff_norm(x))
 
        return x

Large-Scale Training Challenges

Training video generation models is one of the most compute-intensive tasks in AI.

| Challenge | Scale | Mitigation |
|---|---|---|
| Compute | Sora reportedly trained on 10,000+ H100 GPUs | Model parallelism, gradient checkpointing, mixed precision |
| Memory | A single 4-second, 1080p video latent can be 100+ MB | Gradient accumulation, activation checkpointing, offloading |
| Data | Millions of high-quality video-text pairs | Aggressive filtering, synthetic captions, progressive training |
| Stability | Training can diverge at large scale | Learning rate warmup, gradient clipping, loss spike detection |
| Evaluation | No single metric captures video quality | Automated metrics + human eval, temporal quality metrics |
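
Of the mitigations in the table, loss spike detection is the easiest to sketch. The version below flags any loss that exceeds a multiple of the recent running average; the class name, window size, and 3x threshold are assumptions for illustration, and production systems typically pair detection with an action such as skipping the batch or rolling back to a checkpoint.

```python
from collections import deque


class LossSpikeDetector:
    """Flag losses far above the recent running average."""

    def __init__(self, window=100, threshold=3.0, min_history=10):
        self.history = deque(maxlen=window)   # recent non-spike losses
        self.threshold = threshold
        self.min_history = min_history

    def update(self, loss):
        """Record a loss value; return True if it looks like a spike."""
        is_spike = False
        if len(self.history) >= self.min_history:
            mean = sum(self.history) / len(self.history)
            is_spike = loss > self.threshold * mean
        if not is_spike:
            # Keep spikes out of the baseline so one bad step doesn't mask the next
            self.history.append(loss)
        return is_spike


detector = LossSpikeDetector()
losses = [1.0] * 20 + [5.0, 1.1]
flags = [detector.update(value) for value in losses]
print("spike at steps:", [i for i, flagged in enumerate(flags) if flagged])
```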

Progressive training strategy:

Instead of training at full resolution and duration from the start, gradually increase complexity:

Phase 1: Image pretraining
  Train the DiT on images (single frame)
  Resolution: 256x256
  Duration: 10% of total training compute

Phase 2: Short, low-resolution video
  Fine-tune on 16-frame, 256x256 video
  The model learns basic temporal dynamics
  Duration: 30% of total training compute

Phase 3: Longer, higher-resolution video
  Fine-tune on 64-frame, 512x512 video
  The model learns longer temporal consistency
  Duration: 40% of total training compute

Phase 4: Full resolution
  Fine-tune on 96+ frame, 1080p video
  Duration: 20% of total training compute
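
One way to express this curriculum in a training script is a small phase table that maps training progress to a data configuration. The sketch below treats "share of compute" as "share of training steps", which is a simplification (high-resolution steps cost more), and all names are hypothetical.

```python
# (name, frames per clip, resolution, share of training) from the phases above
PHASES = [
    ("image pretraining", 1, "256x256", 0.10),
    ("short low-res video", 16, "256x256", 0.30),
    ("longer higher-res video", 64, "512x512", 0.40),
    ("full resolution", 96, "1080p", 0.20),
]


def phase_for_step(step, total_steps):
    """Return the (name, frames, resolution) active at a given training step."""
    progress = step / total_steps
    cumulative = 0.0
    for name, frames, resolution, share in PHASES:
        cumulative += share
        if progress < cumulative:
            return name, frames, resolution
    # Final step (progress == 1.0) stays in the last phase
    name, frames, resolution, _ = PHASES[-1]
    return name, frames, resolution


for step in (50, 250, 500, 900):
    print(step, phase_for_step(step, total_steps=1000))
```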

Mixed-resolution training: Train on multiple resolutions simultaneously. Pack shorter/smaller videos into the same batch as longer/larger ones to maximize GPU utilization.


T2V Overall System

A complete text-to-video system has many components beyond the diffusion model itself:

┌──────────────────────────────────────────────────────────────────────────┐
│                    Text-to-Video System                                  │
│                                                                          │
│  "A golden retriever running through a field of sunflowers at sunset"    │
│       ↓                                                                  │
│  ┌──────────────────────────────────┐                                   │
│  │  Text Encoder (e.g., T5-XXL)     │                                   │
│  │  Converts text to embeddings     │                                   │
│  └──────────────┬───────────────────┘                                   │
│                 ↓                                                        │
│  ┌──────────────────────────────────┐                                   │
│  │  Video DiT                        │                                   │
│  │  Iterative denoising in latent    │  ← Noise schedule                │
│  │  space with text conditioning     │  ← CFG guidance                  │
│  │  (20-50 sampling steps)           │  ← Sampler (DPM, Euler, etc.)   │
│  └──────────────┬───────────────────┘                                   │
│                 ↓                                                        │
│  ┌──────────────────────────────────┐                                   │
│  │  3D VAE Decoder                   │                                   │
│  │  Latent → pixel-space video       │                                   │
│  └──────────────┬───────────────────┘                                   │
│                 ↓                                                        │
│  ┌──────────────────────────────────┐                                   │
│  │  Post-Processing                  │                                   │
│  │  - Frame interpolation (24→60fps) │                                   │
│  │  - Super-resolution (optional)    │                                   │
│  │  - Temporal smoothing             │                                   │
│  └──────────────┬───────────────────┘                                   │
│                 ↓                                                        │
│  Final video (4-10 seconds, 1080p, 24-60fps)                            │
└──────────────────────────────────────────────────────────────────────────┘

Key components:

| Component | Role | Examples |
|---|---|---|
| Text encoder | Convert prompt to embeddings that condition the diffusion model | T5-XXL, CLIP text encoder, or both |
| Video DiT | The core diffusion model; denoises video latents | Custom architecture per company |
| 3D VAE | Compress/decompress video to/from latent space | Trained separately |
| Sampler | Algorithm for stepping through the denoising process | DDIM, DPM-Solver++, Euler |
| Super-resolution | Upscale to final resolution | Separate diffusion model or upscaler |
| Frame interpolation | Increase frame rate | FILM, RIFE |

Notable T2V models:

| Model | Organization | Key Innovation |
|---|---|---|
| Sora | OpenAI | Scaling DiT to long, high-res video; strong physics understanding |
| Kling | Kuaishou | Competitive with Sora; publicly available |
| Runway Gen-3 | Runway | Motion control, camera movement specification |
| Stable Video Diffusion | Stability AI | Open-weight video diffusion model |
| CogVideoX | Tsinghua/Zhipu | Open-weight, strong temporal coherence |
| Veo 2 | Google DeepMind | High fidelity, long-form video |
| Wan | Alibaba | Open-weight, 14B parameter video DiT |

Part IV: Build a Multi-modal Generation Agent

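Now let's build an agent that uses image generation models as tools. This agent can understand text descriptions, generate images, refine them based on feedback, and orchestrate a multi-step creative workflow. The implementation in this section wires up four tools (generate_image, generate_image_sdxl, analyze_image, refine_prompt); the edit_image and generate_video tools shown in the architecture diagram are not implemented here and are left as extensions.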

Architecture

┌──────────────────────────────────────────────────────────────┐
│              Multi-modal Generation Agent                      │
│                                                                │
│  User: "Create a logo for a coffee shop called 'Bean There'"  │
│       ↓                                                        │
│  ┌───────────────────────────────────────────────────┐        │
│  │  LLM Orchestrator (Claude)                        │        │
│  │                                                   │        │
│  │  1. Analyze the request                           │        │
│  │  2. Craft an optimized image generation prompt    │        │
│  │  3. Call image generation tool                    │        │
│  │  4. Describe the result to the user               │        │
│  │  5. Iterate based on feedback                     │        │
│  └───────────┬───────────────────────────────────────┘        │
│              │                                                 │
│  Tools:      │                                                 │
│  ┌───────────▼───────────┐  ┌──────────────────────┐         │
│  │  generate_image       │  │  edit_image           │         │
│  │  (DALL-E / SD API)    │  │  (inpainting)         │         │
│  └───────────────────────┘  └──────────────────────┘         │
│  ┌───────────────────────┐  ┌──────────────────────┐         │
│  │  analyze_image        │  │  generate_video       │         │
│  │  (vision model)       │  │  (Runway / Kling API) │         │
│  └───────────────────────┘  └──────────────────────┘         │
└──────────────────────────────────────────────────────────────┘

Implementation

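import anthropic
import openai
import json
import base64
import os
import httpx
from pathlib import Path

# --- Clients ---
# Both SDKs read their API keys from the environment
# (ANTHROPIC_API_KEY and OPENAI_API_KEY).
claude = anthropic.Anthropic()
oai = openai.OpenAI()

# Stability AI key, assumed to be set in the STABILITY_API_KEY environment variable
STABILITY_API_KEY = os.environ.get("STABILITY_API_KEY", "")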
 
# --- Tool Definitions ---
tools = [
    {
        "name": "generate_image",
        "description": (
            "Generate an image from a text prompt using DALL-E. "
            "The prompt should be detailed and descriptive. "
            "Returns a URL to the generated image."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "Detailed description of the image to generate"
                },
                "size": {
                    "type": "string",
                    "enum": ["1024x1024", "1792x1024", "1024x1792"],
                    "description": "Image dimensions"
                },
                "style": {
                    "type": "string",
                    "enum": ["vivid", "natural"],
                    "description": "vivid for hyper-real/dramatic, natural for realistic"
                }
            },
            "required": ["prompt"]
        }
    },
    {
        "name": "generate_image_sdxl",
        "description": (
            "Generate an image using Stable Diffusion XL via the Stability API. "
            "Good for artistic styles, illustrations, and when you need more "
            "control over the generation process."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "Detailed description of the image to generate"
                },
                "negative_prompt": {
                    "type": "string",
                    "description": "What to avoid in the image (e.g., 'blurry, low quality')"
                },
                "style_preset": {
                    "type": "string",
                    "enum": [
                        "photographic", "digital-art", "comic-book",
                        "fantasy-art", "analog-film", "cinematic"
                    ],
                    "description": "Visual style preset"
                }
            },
            "required": ["prompt"]
        }
    },
    {
        "name": "analyze_image",
        "description": (
            "Analyze an image using Claude's vision capabilities. "
            "Can describe what's in the image, evaluate quality, "
            "suggest improvements, or answer questions about it."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "image_url": {
                    "type": "string",
                    "description": "URL of the image to analyze"
                },
                "question": {
                    "type": "string",
                    "description": "What to analyze or ask about the image"
                }
            },
            "required": ["image_url", "question"]
        }
    },
    {
        "name": "refine_prompt",
        "description": (
            "Take a user's simple description and expand it into an "
            "optimized prompt for image generation. Adds details about "
            "lighting, composition, style, and technical parameters."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "user_description": {
                    "type": "string",
                    "description": "The user's original, possibly brief description"
                },
                "style_preference": {
                    "type": "string",
                    "description": "Preferred style (photorealistic, illustration, etc.)"
                }
            },
            "required": ["user_description"]
        }
    }
]
 
 
# --- Tool Implementations ---
def generate_image_dalle(prompt, size="1024x1024", style="vivid"):
    """Generate an image using DALL-E 3."""
    response = oai.images.generate(
        model="dall-e-3",
        prompt=prompt,
        size=size,
        style=style,
        n=1,
    )
    return json.dumps({
        "image_url": response.data[0].url,
        "revised_prompt": response.data[0].revised_prompt
    })
 
 
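def generate_image_sdxl(prompt, negative_prompt="", style_preset="photographic"):
    """Generate an image via Stability AI's Stable Image API.

    Note: the tool keeps its SDXL name, but the endpoint below targets
    the SD3 model family.
    """
    import hashlib  # local import: only needed for the output filename

    response = httpx.post(
        "https://api.stability.ai/v2beta/stable-image/generate/sd3",
        headers={
            "Authorization": f"Bearer {STABILITY_API_KEY}",
            "Accept": "image/*",  # ask for raw image bytes rather than JSON
        },
        files={"none": ""},
        data={
            "prompt": prompt,
            "negative_prompt": negative_prompt or "blurry, low quality, distorted",
            "style_preset": style_preset,
            "output_format": "png",
        },
        timeout=120.0,  # generation can take longer than httpx's 5-second default
    )
    response.raise_for_status()  # don't write an error body to disk as a PNG

    # Save image and return path. Python's hash() is salted per process,
    # so use a stable digest to get a reproducible filename.
    digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
    path = f"generated_{digest}.png"
    Path(path).write_bytes(response.content)
    return json.dumps({"image_path": path, "prompt_used": prompt})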
 
 
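def analyze_image(image_url, question):
    """Analyze an image using Claude's vision.

    Accepts either an http(s) URL (as returned by the DALL-E tool) or a
    local file path (as returned by the Stability tool).
    """
    if image_url.startswith(("http://", "https://")):
        source = {"type": "url", "url": image_url}
    else:
        # Local file: send the bytes inline. Assumes PNG, which is what
        # generate_image_sdxl writes.
        data = base64.standard_b64encode(Path(image_url).read_bytes()).decode("utf-8")
        source = {"type": "base64", "media_type": "image/png", "data": data}

    response = claude.messages.create(
        model="claude-sonnet-4-6",
        max_tokens=1024,
        messages=[{
            "role": "user",
            "content": [
                {"type": "image", "source": source},
                {"type": "text", "text": question}
            ]
        }]
    )
    return response.content[0].text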
 
 
def refine_prompt(user_description, style_preference=""):
    """Expand a simple description into a detailed generation prompt."""
    response = claude.messages.create(
        model="claude-sonnet-4-6",
        max_tokens=500,
        messages=[{
            "role": "user",
            "content": f"""Expand this image description into a detailed prompt
optimized for AI image generation. Add specific details about:
- Composition and framing
- Lighting and atmosphere
- Color palette
- Style and medium
- Important details to include
 
User description: {user_description}
Style preference: {style_preference or "not specified"}
 
Return ONLY the expanded prompt, nothing else."""
        }]
    )
    return response.content[0].text
 
 
def execute_tool(name, args):
    """Route tool calls to implementations."""
    if name == "generate_image":
        return generate_image_dalle(
            args["prompt"],
            args.get("size", "1024x1024"),
            args.get("style", "vivid")
        )
    elif name == "generate_image_sdxl":
        return generate_image_sdxl(
            args["prompt"],
            args.get("negative_prompt", ""),
            args.get("style_preset", "photographic")
        )
    elif name == "analyze_image":
        return analyze_image(args["image_url"], args["question"])
    elif name == "refine_prompt":
        return refine_prompt(
            args["user_description"],
            args.get("style_preference", "")
        )
    return f"Unknown tool: {name}"
 
 
# --- Agent ---
SYSTEM_PROMPT = """You are a multi-modal creative agent that helps users generate
and refine images and videos. You have access to multiple generation models.
 
Your workflow:
1. Understand what the user wants to create
2. Use refine_prompt to expand vague descriptions into detailed generation prompts
3. Choose the appropriate generation tool (DALL-E for photorealism, SDXL for artistic styles)
4. After generation, use analyze_image to evaluate the result
5. If the result needs improvement, iterate with adjusted prompts
6. Present the result to the user with a description of what was generated
 
Tips for great prompts:
- Be specific about composition, lighting, and style
- Include the medium (photograph, oil painting, 3D render, etc.)
- Specify camera angle if relevant (close-up, wide shot, aerial view)
- Include mood and atmosphere keywords
- For logos/designs, specify that text should be clear and readable
 
Always explain your creative decisions to the user."""
 
 
def creative_agent(user_request: str):
    """Run the multi-modal generation agent."""
    messages = [{"role": "user", "content": user_request}]
 
    for step in range(10):  # max 10 tool-use steps
        response = claude.messages.create(
            model="claude-sonnet-4-6",
            max_tokens=4096,
            system=SYSTEM_PROMPT,
            tools=tools,
            messages=messages,
        )
 
        messages.append({"role": "assistant", "content": response.content})
 
        if response.stop_reason == "tool_use":
            tool_results = []
            for block in response.content:
                if block.type == "tool_use":
                    print(f"  Tool: {block.name}({json.dumps(block.input)[:100]}...)")
                    result = execute_tool(block.name, block.input)
                    print(f"  Result: {result[:200]}...")
                    tool_results.append({
                        "type": "tool_result",
                        "tool_use_id": block.id,
                        "content": result
                    })
            messages.append({"role": "user", "content": tool_results})
        else:
            # Final response
            text = "".join(
                block.text for block in response.content if hasattr(block, "text")
            )
            return text
 
    return "Agent reached maximum steps."
 
 
# --- Run it ---
if __name__ == "__main__":
    result = creative_agent(
        "Create a logo for a coffee shop called 'Bean There, Done That'. "
        "The style should be warm and inviting, with a vintage feel."
    )
    print(result)
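
One thing the loop above leaves open is what happens when a tool call fails: the router indexes `args["prompt"]` directly and the generation functions call external APIs, so an exception would end the agent mid-turn. A minimal sketch of one way to handle this, assuming you would rather report the failure to the model and let it retry. The helper name `safe_tool_result` and the stand-in `fake_router` are hypothetical and not part of the original agent; the `is_error` field is the flag the Anthropic Messages API accepts on `tool_result` blocks.

```python
import json


def safe_tool_result(tool_use_id, name, args, executor):
    """Run one tool call and always return a well-formed tool_result block.

    `executor` stands in for a router such as execute_tool(name, args).
    """
    try:
        content = executor(name, args)
        if not isinstance(content, str):
            content = json.dumps(content)  # send non-string results as JSON text
        return {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
    except Exception as exc:
        # Report the failure to the model so it can retry with corrected arguments.
        return {
            "type": "tool_result",
            "tool_use_id": tool_use_id,
            "content": f"{type(exc).__name__}: {exc}",
            "is_error": True,
        }


# Usage with a stand-in router that requires a "prompt" argument (hypothetical).
def fake_router(name, args):
    return json.dumps({"tool": name, "prompt": args["prompt"]})


ok = safe_tool_result("toolu_1", "generate_image", {"prompt": "a red fox"}, fake_router)
failed = safe_tool_result("toolu_2", "generate_image", {}, fake_router)
print(ok["content"])
print(failed["content"], failed["is_error"])
```

Inside `creative_agent`, the call `execute_tool(block.name, block.input)` could be replaced with `safe_tool_result(block.id, block.name, block.input, execute_tool)` and the returned dict appended to `tool_results` as-is.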

Iterative Refinement Loop

The real power of an agent-based approach is iterative refinement. The agent can generate, analyze, critique, and regenerate:

def iterative_generation(
    user_request: str,
    max_iterations: int = 3
) -> list:
    """
    Generate an image, analyze it, and iterate until satisfied.
    Returns the full history of generations.
    """
    history = []
 
    # Initial prompt refinement
    refined_prompt = refine_prompt(user_request)
    print(f"Refined prompt: {refined_prompt[:100]}...")
 
    for iteration in range(max_iterations):
        print(f"\n--- Iteration {iteration + 1} ---")
 
        # Generate
        result = json.loads(generate_image_dalle(refined_prompt))
        image_url = result["image_url"]
        print(f"Generated image: {image_url[:80]}...")
 
        # Analyze
        analysis = analyze_image(
            image_url,
            f"Evaluate this image against the original request: '{user_request}'. "
            f"Score it 1-10 on: accuracy to request, visual quality, composition. "
            f"List specific issues that should be fixed. "
            f"End with a final line reading exactly 'VERDICT: APPROVED' if it "
            f"scores 8+ on all criteria, otherwise 'VERDICT: REVISE'."
        )
        print(f"Analysis: {analysis[:200]}...")
 
        history.append({
            "iteration": iteration + 1,
            "prompt": refined_prompt,
            "image_url": image_url,
            "analysis": analysis
        })
 
        # Check if approved. Match the explicit verdict line rather than the
        # bare word, which would also match phrases like "not approved".
        if "VERDICT: APPROVED" in analysis.upper():
            print("Image approved!")
            break
 
        # Refine prompt based on critique
        refinement = claude.messages.create(
            model="claude-sonnet-4-6",
            max_tokens=500,
            messages=[{
                "role": "user",
                "content": f"""The previous image generation wasn't quite right.
 
Original request: {user_request}
Prompt used: {refined_prompt}
Critique: {analysis}
 
Write an improved prompt that addresses the specific issues mentioned
in the critique. Return ONLY the new prompt."""
            }]
        )
        refined_prompt = refinement.content[0].text
        print(f"New prompt: {refined_prompt[:100]}...")
 
    return history
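
The stop condition in this loop depends on a string test over free-form critique text. One alternative, sketched here under stated assumptions, is to ask the critic for JSON scores and apply the 8+ threshold in code. The JSON shape (`scores` with `accuracy`, `quality`, and `composition` keys, plus an `issues` list) and the helper name `parse_critique` are hypothetical; the analysis prompt would need to request that format explicitly.

```python
import json
import re

CRITERIA = ("accuracy", "quality", "composition")


def parse_critique(analysis: str, threshold: int = 8) -> dict:
    """Pull a JSON object out of the critique text and decide approval in code."""
    empty = {"approved": False, "scores": {}, "issues": []}

    # The model may wrap the JSON in prose, so grab the outermost {...} span.
    match = re.search(r"\{.*\}", analysis, re.DOTALL)
    if match is None:
        return empty
    try:
        data = json.loads(match.group(0))
    except json.JSONDecodeError:
        return empty

    raw = data.get("scores")
    if not isinstance(raw, dict):
        raw = {}
    # Keep only the expected criteria that carry numeric scores.
    scores = {k: raw[k] for k in CRITERIA if isinstance(raw.get(k), (int, float))}

    # Approve only when every criterion is present and meets the threshold.
    approved = len(scores) == len(CRITERIA) and all(
        value >= threshold for value in scores.values()
    )
    return {"approved": approved, "scores": scores, "issues": data.get("issues", [])}


# Usage with a hand-written critique string (not real model output).
sample = (
    "Here is my review:\n"
    '{"scores": {"accuracy": 9, "quality": 6, "composition": 8}, '
    '"issues": ["shop name is misspelled"]}'
)
verdict = parse_critique(sample)
print(verdict["approved"], verdict["issues"])
```

Missing or unparseable scores count as "not approved", so a malformed critique triggers another iteration instead of ending the loop early.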

What You Should Know After Reading This

  1. What are the four main approaches to image generation (VAE, GAN, auto-regressive, diffusion)?
  2. How does the VAE reparameterization trick work and why is it necessary?
  3. How does GAN training work and why is it unstable?
  4. How do VQ-VAEs turn images into discrete tokens for auto-regressive generation?
  5. What is the diffusion forward process (adding noise) and reverse process (denoising)?
  6. What is latent diffusion and why is it more efficient than pixel-space diffusion?
  7. What is the difference between U-Net and DiT architectures for diffusion?
  8. How does cross-attention enable text conditioning in T2I models?
  9. What is classifier-free guidance (CFG) and how does it improve text-image alignment?
  10. What are FID, IS, and CLIP score, and what does each measure?
  11. How does a 3D VAE compress video spatially and temporally?
  12. What is factored attention and why is it necessary for video generation?
  13. What is the progressive training strategy for T2V models?
  14. How would you build an agent that uses image generation as a tool?

Further Reading

  • "Auto-Encoding Variational Bayes" (Kingma and Welling, 2013) — The original VAE paper
  • "Generative Adversarial Nets" (Goodfellow et al., 2014) — The original GAN paper
  • "Denoising Diffusion Probabilistic Models" (Ho et al., 2020) — DDPM, the foundational diffusion paper
  • "High-Resolution Image Synthesis with Latent Diffusion Models" (Rombach et al., 2022) — Stable Diffusion / Latent Diffusion
  • "Scalable Diffusion Models with Transformers" (Peebles and Xie, 2023) — The DiT paper
  • "Denoising Diffusion Implicit Models" (Song et al., 2020) — DDIM fast sampling
  • "Classifier-Free Diffusion Guidance" (Ho and Salimans, 2022) — CFG
  • "Neural Discrete Representation Learning" (van den Oord et al., 2017) — VQ-VAE
  • "Video generation models as world simulators" (OpenAI, 2024) — Sora technical report
  • "CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer" (Yang et al., 2024)
  • "Stable Video Diffusion: Scaling Latent Video Diffusion Models to Large Datasets" (Blattmann et al., 2023)
  • "Understanding Diffusion Models: A Unified Perspective" (Luo, 2022) — Excellent tutorial paper
