2024
1. Diffusion Models
2. Building up to Stable Diffusion
1. Diffusion Models
ก่อน 2020 คนไม่ค่อยรู้จัก diffusion model เพราะมันปนๆ อยู่ใน ML
– เพื่อ generate synthetic image ได้ดีกว่า model อื่นๆ
2021 -model ที่คล้ายๆ GLIDE , text-to-image -> เข้าสู่ mainstream DALL-E2 and stable diffusion
เพื่อให้เราสามารถ image โดยแค่พิมพ์text
p 6 , Key insight
ก่อนหน้าที่ พูดถึง VAEs or GANs
pre-trained model ใช้ Hugging Face diffusers Library
pipelie สร้าง image
# Load the pipeline
image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-
256")
image_pipe.to(device);
# Sample an image
image_pipe().images[0]

p 7, re-create sampling process – step by step – model generates images
sample x , c random noise , run ผ่าน 30 steps
rt เราจะเห็น model’s prediction – เพื่อดู final image – ยังดูไม่ดีมาก – jump ไปยัง final predicted image
– only modify x โดย small amount in direction of prediction – เห็นด้าน L
ภาพใหม่ ดีกว่า x
# The random starting point for a batch of 4 images
x = torch.randn(4, 3, 256, 256).to(device)
# Set the number of timesteps lower
image_pipe.scheduler.set_timesteps(num_inference_steps=30)
# Loop through the sampling timesteps
for i, t in enumerate(image_pipe.scheduler.timesteps):
# Get the prediction given the current sample x and the
timestep t
with torch.no_grad():
noise_pred = image_pipe.unet(x, t)["sample"]
# Calculate what the updated sample should look like with the
scheduler
scheduler_output = image_pipe.scheduler.step(noise_pred, t,
x)
# Update x
x = scheduler_output.prev_sample
# Occasionally display both x and the predicted denoised
images
if i % 10 == 0 or i == len(image_pipe.scheduler.timesteps) -
1:
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
grid = torchvision.utils.make_grid(x, nrow=4).permute(1,
2, 0)
axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
axs[0].set_title(f"Current x (step {i})")
pred_x0 = scheduler_output.pred_original_sample
grid = torchvision.utils.make_grid(pred_x0,
nrow=4).permute(1, 2, 0)
axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
axs[1].set_title(f"Predicted denoised images (step {i})")
plt.show()

p 9 ,
