Generator & discriminator (DCGAN-style sketch)
```python
import torch
import torch.nn as nn

nz = 100  # noise dim

class G(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(nz, 256 * 7 * 7),
            nn.Unflatten(1, (256, 7, 7)),                        # (B, 256, 7, 7)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),   # -> 14x14
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),     # -> 28x28
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

class D(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),    # 28x28 -> 14x14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),  # -> 7x7
            nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False), # -> 3x3
            nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(256 * 3 * 3, 1),                 # single logit
        )

    def forward(self, x):
        return self.net(x)
```
Two stride-2 transposed convs upsample 7×7→14×14→28×28. D outputs a raw logit, so train it with `BCEWithLogitsLoss`. Normalize real images to [-1, 1] so they match the generator's Tanh output range.
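The shape bookkeeping above can be checked with the standard conv-arithmetic formulas (a quick sketch, independent of PyTorch):

```python
def conv_out(n, k=4, s=2, p=1):
    # Standard convolution output size: floor((n + 2p - k)/s) + 1
    return (n + 2 * p - k) // s + 1

def convT_out(n, k=4, s=2, p=1):
    # Transposed convolution output size: (n - 1)*s - 2p + k
    return (n - 1) * s - 2 * p + k

# Generator upsampling path: 7 -> 14 -> 28
print(convT_out(7), convT_out(convT_out(7)))  # 14 28

# Discriminator downsampling path: 28 -> 14 -> 7 -> 3
n = 28
for _ in range(3):
    n = conv_out(n)
print(n)  # 3
```

The final 3×3 map is why D's linear layer takes `256 * 3 * 3` inputs.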
One training iteration (non-saturating G loss)
```python
device = "cuda" if torch.cuda.is_available() else "cpu"
G_m, D_m = G().to(device), D().to(device)

# DCGAN-style hyperparameters: lr=2e-4, beta1=0.5
optG = torch.optim.Adam(G_m.parameters(), lr=2e-4, betas=(0.5, 0.999))
optD = torch.optim.Adam(D_m.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()

def step_D(real):
    b = real.size(0)
    z = torch.randn(b, nz, device=device)
    fake = G_m(z).detach()  # stop gradients from reaching G
    loss = bce(D_m(real), torch.ones(b, 1, device=device))    # real -> 1
    loss += bce(D_m(fake), torch.zeros(b, 1, device=device))  # fake -> 0
    optD.zero_grad()
    loss.backward()
    optD.step()
    return loss.item()

def step_G(b):
    z = torch.randn(b, nz, device=device)
    out = D_m(G_m(z))  # no detach: G needs gradients through D
    # Non-saturating loss: push fakes toward the "real" label
    loss = bce(out, torch.ones(b, 1, device=device))
    optG.zero_grad()
    loss.backward()
    optG.step()
    return loss.item()
```
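The `detach()` in `step_D` is the easy part to get wrong. A minimal sketch with toy stand-in layers (hypothetical tiny nets, not the models above) shows what it does: D's loss updates D only, never G.

```python
import torch
import torch.nn as nn

g = nn.Linear(3, 2)  # plays the role of G
d = nn.Linear(2, 1)  # plays the role of D
bce = nn.BCEWithLogitsLoss()

z = torch.randn(4, 3)
fake = g(z).detach()  # cut the autograd graph at G's output
loss = bce(d(fake), torch.zeros(4, 1))
loss.backward()

print(g.weight.grad is None)      # True: no gradient reached G
print(d.weight.grad is not None)  # True: D still gets gradients
```

In `step_G` the fake is deliberately not detached, so gradients flow through D back into G.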
Practical tips
- Train D more often than G early on if D is too weak, or the reverse if D dominates.
- Use spectral norm / WGAN-GP in harder setups for stability.
- Mode collapse: the generator outputs limited variety; detect it via sample-diversity metrics.
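For the spectral-norm option above, one common sketch is to wrap each of D's weight layers in `torch.nn.utils.spectral_norm` (and drop BatchNorm, as spectrally normalized discriminators usually do); layer sizes here mirror the discriminator in this section:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

d_sn = nn.Sequential(
    spectral_norm(nn.Conv2d(1, 64, 4, 2, 1, bias=False)),    # 28 -> 14
    nn.LeakyReLU(0.2, inplace=True),
    spectral_norm(nn.Conv2d(64, 128, 4, 2, 1, bias=False)),  # 14 -> 7
    nn.LeakyReLU(0.2, inplace=True),
    spectral_norm(nn.Conv2d(128, 256, 4, 2, 1, bias=False)), # 7 -> 3
    nn.LeakyReLU(0.2, inplace=True),
    nn.Flatten(),
    spectral_norm(nn.Linear(256 * 3 * 3, 1)),
)

out = d_sn(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 1])
```

Training then proceeds exactly as in `step_D`/`step_G`; spectral norm only constrains D's Lipschitz constant.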
Takeaways
- Minimax game between G and D; non-saturating loss common for G.
- DCGAN: strided conv D, transposed conv G, BN, no pooling in core blocks.
- Modern CV often uses diffusion or autoregressive models for higher fidelity; GANs still teach the adversarial idea.
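In symbols, the minimax game from the takeaways (with $D(\cdot)$ denoting the sigmoid of D's logit) is

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big]
+ \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big],
$$

and the non-saturating variant trains G to maximize $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, which is exactly what `bce(out, ones)` computes on the logits. It gives G usable gradients early in training, when D confidently rejects fakes and $\log(1 - D(G(z)))$ is nearly flat.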
Quick FAQ
Why does the generator end in Tanh? It matches [-1, 1]-normalized images and pairs cleanly with the LeakyReLU discriminator; a sigmoid output with BCE on [0, 1] images is an alternative.
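The matching [-1, 1] normalization on the real images is a one-liner (equivalently, torchvision's `transforms.Normalize((0.5,), (0.5,))`):

```python
import torch

x = torch.rand(4, 1, 28, 28)  # images in [0, 1], e.g. from ToTensor()
x = x * 2 - 1                 # now in [-1, 1], matching Tanh's range
print(float(x.min()) >= -1.0, float(x.max()) <= 1.0)  # True True
```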
How can generation be made class-conditional? Concatenate a class embedding or channel-wise label maps to z and/or intermediate features so generation is class-controlled.
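A minimal sketch of that conditioning idea, reusing this section's generator architecture (the class name `CondG` and the embedding width of 50 are illustrative choices, not from the text above):

```python
import torch
import torch.nn as nn

nz, n_classes, emb = 100, 10, 50  # emb width is an arbitrary choice

class CondG(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(n_classes, emb)
        self.net = nn.Sequential(
            nn.Linear(nz + emb, 256 * 7 * 7),  # noise + label embedding
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z, y):
        # Concatenate the label embedding onto z along the feature dim
        return self.net(torch.cat([z, self.embed(y)], dim=1))

z = torch.randn(4, nz)
y = torch.randint(0, n_classes, (4,))
img = CondG()(z, y)
print(img.shape)  # torch.Size([4, 1, 28, 28])
```

For a matching conditional discriminator, the same trick applies: broadcast the label embedding into channel-wise maps and concatenate them with the image input.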