
[AI Study] Section 11 : GAN (Generative Adversarial Network)

egahyun 2024. 12. 27. 03:44

GAN (์ ๋Œ€์  ์ƒ์„ฑ ๋ชจ๋ธ) ๋ฐœ์ „

  1. ์ตœ์ดˆ ์ œ์•ˆ : (2014) Ian Goodfellow
  2. ์‚ฌ์šฉ : Computer ๊ฐ€ ์ด๋ฏธ์ง€, ์ธ๊ฐ„์˜ ๋ชฉ์†Œ๋ฆฌ, ์•…๊ธฐ์†Œ๋ฆฌ ๋“ฑ์„ ์‹ค์ œ์™€ ๊ฐ™์ด ์ƒ์„ฑ
    https://www.thispersondoesnotexist.com/ : GAN ์‚ฐ์ถœ๋ฌผ์˜ ํ’ˆ์งˆ ๋ณ€ํ™”๋ฅผ ๋ณผ ์ˆ˜ ์žˆ๋Š” ๊ฐ€์ƒ ์ธ๋ฌผ
  3. ๊ทน์ฐฌ์„ ๋ฐ›์•˜๋˜ ์ด์œ 
    : ๋‘ ๊ฐœ์˜ ๋”ฅ ๋‰ด๋Ÿด ๋„คํŠธ์›Œํฌ๋กœ ์ด๋ฃจ์–ด์ง : ์ƒ์„ฑ์ž / ํŒ๋ณ„์ž ⇒ ๋‘ ๋„คํŠธ์›Œํฌ๊ฐ€ ์„œ๋กœ ์ ๋Œ€์ ์œผ๋กœ ์ƒ์„ฑํ•˜๋Š” ๋ชจ๋ธ

Probability Basics

  1. Discrete probability: each outcome takes a distinct, separate value ⇒ e.g., a die
  2. Continuous probability distribution: values vary continuously ⇒ e.g., an image: a probability distribution over 64x64x3 values

  3. Probability distribution by image characteristics
    : the distribution of pixel values differs according to the image's characteristics
    ⇒ each individual pixel is associated with some feature of the image
    ⇒ e.g., the pixels making up skin, glasses, etc. differ from image to image, so we see photographs of different shapes

  4. Multivariate probability distribution over image features
    : images generated in the highest-probability region will be the ones most similar to real images

์ƒ์„ฑ ๋ชจ๋ธ

: ์‹ค์ œ ์กด์žฌํ•˜์ง€๋Š” ์•Š์ง€๋งŒ ์žˆ์„ ๋ฒ•ํ•œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋Š” ๋ชจ๋ธ

  1. ๋ถ„๋ฅ˜๋ชจ๋ธ๊ณผ์˜ ์ฐจ์ด
    • ๋ถ„๋ฅ˜ ๋ชจ๋ธ : ๊ฒฐ์ • ๊ฒฝ๊ณ„๋ฅผ ํ•™์Šต
    • ์ƒ์„ฑ ๋ชจ๋ธ : ๊ฐ ํด๋ž˜์Šค์˜ ๋ถ„ํฌ๋ฅผ ํ•™์Šต ⇒ ๊ฒฐํ•ฉํ™•๋ฅ  (joint probability) ์„ ํ•™์Šต
  2. ๊ฒฐํ•ฉํ™•๋ฅ 
    ⇒ ๋ฌด์—‡์„ ํ•™์Šตํ•˜๋А๋ƒ์˜ ๋Œ€์ƒ์— ๋”ฐ๋ผ ์ด๋ฏธ์ง€ ํ€„๋ฆฌํ‹ฐ๊ฐ€ ๋‹ฌ๋ผ์ง
    • ํ”ฝ์…€์˜ ๋ถ„ํฌ ์ž์ฒด๊ฐ€ ๋‚ฎ์€ ํ™•๋ฅ ์˜ ๋ถ„ํฌ๋ฅผ ํ•™์Šต ์‹œ, ์–ด์ƒ‰ํ•œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ
    • ๋†’์€ ํ™•๋ฅ ์˜ ๋ถ„ํฌ๋ฅผ ํ•™์Šต ์‹œ, ๊ทธ๋Ÿด๋“ฏํ•œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑ

 

GAN์˜ ๋ชฉํ‘œ

: ๋‘๊ฐœ ํ™•๋ฅ  ๋ถ„ํฌ์˜ ์ฐจ์ด๋ฅผ ์ค„์—ฌ ์ฃผ๋Š” ๊ฒƒ

  1. ํ™•๋ฅ ๋ถ„ํฌ : ์‹ค์ œ ์ด๋ฏธ์ง€์˜ ๊ฐ ํ”ฝ์…€์ด ์–ด๋–ค ๊ฐ’์„ ๊ฐ€์ง€๊ณ  ์žˆ์„ ๊ฒƒ์ธ๊ฐ€
    ⇒ ํŠน์ง• : ์ธ์ ‘ํ•œ ํ”ฝ์…€๋“ค ๋ผ๋ฆฌ ์—ฐ๊ด€์ด ๋˜์–ด์žˆ์Œ (ex : ํ”ฝ์…€์ด ๋‹ฌ๋ผ์ง„๋‹ค๊ณ  ํ”ผ๋ถ€์ƒ‰์ด ๋…ธ๋ž€์ƒ‰์ด ์ดˆ๋ก์ƒ‰์ด ๋˜์ง€ ์•Š์Œ)

  2. ํ™•๋ฅ  ๋ชจ๋ธ : ๋ชจ๋ธ์ด ์ƒ์„ฑํ•œ ์ด๋ฏธ์ง€์˜ ํ™•๋ฅ ๋ถ„ํฌ
    ⇒ ํ™•๋ฅ  ๋ฐ์ดํ„ฐ X : ์‹ค์ œ ๋ฐ์ดํ„ฐ
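At its optimum, the original GAN objective minimizes the Jensen-Shannon divergence between the data distribution and the generator distribution. A minimal NumPy sketch over small discrete distributions (the three-bin distributions are made up for illustration):

```python
import numpy as np

def kl(p, q):
    """Kullback-Leibler divergence KL(p || q) for discrete distributions."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def js(p, q):
    """Jensen-Shannon divergence: symmetric, and zero only when p == q."""
    m = 0.5 * (np.asarray(p, float) + np.asarray(q, float))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p_data = np.array([0.1, 0.4, 0.5])     # stand-in for the real-data distribution
p_g_early = np.array([0.8, 0.1, 0.1])  # an early generator, far from the data
p_g_late = np.array([0.1, 0.4, 0.5])   # a converged generator, matching the data

print(js(p_data, p_g_early))  # about 0.276
print(js(p_data, p_g_late))   # 0.0
```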

 

GAN Training

 

๊ฒ€์€ ์ ์„  (a) : ์›๋ณธ ๋ฐ์ดํ„ฐ ์ด๋ฏธ์ง€์˜ ๋ถ„ํฌ (\(p_{data}(x)\)) → ๋ฐ์ดํ„ฐ๊ฐ€ ์œ ํ•œํ•˜๋ฏ€๋กœ ์ ์œผ๋กœ ํ‘œ์‹œ
ํŒŒ๋ž€ ์ ์„  (b) : discriminator distribution
๋…น์ƒ‰ ์„  (c) : generator distribution ( \( p_z(z) \) )
  ⇒ x(real), z(fake) ์„  : ๊ฐ๊ฐ x, z์˜ ๋„๋ฉ”์ธ
  ⇒ ์œ„๋กœ ๋ป—์€ ํ™”์‚ดํ‘œ : \( x=G(z) \) ์˜ ๋งคํ•‘

  1. ( a )
    • the peak: where the real data is most heavily concentrated
  2. ( b )
    • the initial generator: its distribution is skewed to the right, producing data that would look like an unnatural person
    • the generator: its distribution shifts left, toward the data
  3. ( c )
    • the discriminator: distinguishes real from fake well
    • the distribution of the generator's images keeps moving toward the center of the data
  4. ( d )
    • the generated data now matches the real data ⇒ the generator has learned the real data's distribution perfectly
    • the generator: mimics the distribution and produces identical pixel statistics
      ⇒ the discriminator: outputs a probability of 0.5 (meaning "I can't tell!")

Generative vs. Discriminative Algorithms

  1. Discriminative Algorithm
    : predicts a label from the features of the input data (e.g., spam classification)
    → the usual classification model
    → p(y | X) → "the probability of y given X"

  2. Generative Algorithm
    : predicts features given a label (not feature extraction, but feature filling)
    → p(X | y) → "the probability of features given y"
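A toy 1-D sketch of the contrast above, with a made-up two-class Gaussian dataset: the discriminative view only needs a decision boundary for p(y | X), while the generative view models p(X | y) per class and can therefore sample new data:

```python
import numpy as np

rng = np.random.default_rng(42)

# Made-up two-class dataset: one feature x, labels y in {0, 1}
x0 = rng.normal(loc=-2.0, scale=1.0, size=1000)  # samples with y = 0
x1 = rng.normal(loc=+2.0, scale=1.0, size=1000)  # samples with y = 1

# Discriminative view: only a decision boundary is needed for p(y | X)
boundary = (x0.mean() + x1.mean()) / 2  # roughly 0 here

def predict(x):
    return int(x > boundary)

# Generative view: model p(X | y) per class -- here a Gaussian per label --
# which also lets us *generate* new feature values for a given label
class_params = {0: (x0.mean(), x0.std()), 1: (x1.mean(), x1.std())}

def sample(y):
    mu, sigma = class_params[y]
    return rng.normal(mu, sigma)

print(predict(3.0), predict(-3.0))  # 1 0
```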

GAN process (training procedure)

  1. Training the discriminator
    • Generator: draws random numbers to create a random (fake) image
    • The generator's fake images (labeled 0) and the real images from the actual dataset (labeled 1) are supplied to the discriminator
    • Discriminator: trained as a binary classifier to output 1 for real images and 0 for fake images
    • i.e., the discriminator is trained on a binary classification problem
  2. Training the generator
    (the generator's goal: to fool the discriminator)
    • The discriminator's output is compared against the label 1 with cross-entropy; the difference is treated as the loss and corrected via backpropagation
      ⇒ when a generated fake image is predicted as 0, the weights are adjusted by backpropagation to reduce the loss
      ⇒ as the loss shrinks, the discriminator can no longer tell fake from real and ends up outputting 1

GAN Objective Function

: the two networks are trained simultaneously → in practice, one is held fixed while the other is trained

 

\( \min_G \max_D V(D,G) = E_x[\log D(x)] + E_z[\log(1 - D(G(z)))] \)

x : a sample drawn from the real data

z : a sample drawn from N(0,1) (noise)

D(x) : the probability the discriminator assigns to x being real

G(z) : the fake image the generator produces from z

D(G(z)) : the probability the discriminator assigns to a fake being real

Ex : the expectation over all real data instances

Ez : the expectation over random inputs to the generator

  1. Discriminator (the classifier \(D(x)\))
    The closer D(x) and 1-D(G(z)) are to 1, the larger \(\max_D V(D)\): both terms must be maximized

    • D(x): trained to return 1 for real and 0 for fake
      the closer D(x) is to 1, the better ⇒ left term \(\log D(x)\): larger is better

    • D(G(z)): returns 1 when it judges G(z) to be real, 0 when it judges it fake
      • the closer to 0, the better ⇒ right term \(\log(1-D(G(z)))\): larger is better
      • loss is measured with cross-entropy → using \(p \log(q)\)
  2. Generator
    wants 1-D(G(z)) to go to 0 (D(x) does not involve G, so it is ignored)
    D(G(z)): the closer to 1, the better ⇒ right term \(\log(1-D(G(z)))\): smaller is better
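The two terms of the value function can be evaluated directly; note that at the equilibrium described in panel (d), where the discriminator outputs 0.5 everywhere, V takes the value \(-\log 4\):

```python
import numpy as np

def V_terms(D_x, D_G_z):
    """The two terms of V(D, G) for one real sample x and one fake sample G(z):
    log D(x)         -- the discriminator pushes this toward log 1 = 0
    log(1 - D(G(z))) -- the discriminator wants D(G(z)) near 0 (term large),
                        the generator wants D(G(z)) near 1 (term small)."""
    return np.log(D_x), np.log(1.0 - D_G_z)

# A confident discriminator: high D(x) on real, low D(G(z)) on fake
real_term, fake_term = V_terms(D_x=0.9, D_G_z=0.1)

# The fooled equilibrium: D outputs 0.5 for everything
half_real, half_fake = V_terms(D_x=0.5, D_G_z=0.5)

print(real_term + fake_term)  # about -0.211
print(half_real + half_fake)  # about -1.386 = -log 4
```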

Mode Collapse

: a phenomenon where the generator only produces the modes that most easily fool the discriminator

⇒ e.g., generating only similar-looking digits

 


Hands-on: Building a GAN (forging the MNIST dataset)

Model structure and caveats

  1. Discriminator ์˜ Goal
    : mnist dataset ์„ “์ง„์งœ” ๋กœ ์ธ์‹ํ•˜๊ณ , Generator ์—์„œ ๊ณต๊ธ‰๋˜๋Š” image ๋ฅผ fake ๋กœ ๊ตฌ๋ถ„
    → Generator train : Discriminator ๋ฅผ constant ๋กœ freeze ํ•˜์—ฌ gradient๋ฅผ ์•ˆ์ •์ ์œผ๋กœ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•œ๋‹ค

  2. Generator ์˜ Goal
    : discriminator ๊ฐ€ “์ง„์งœ” ๋กœ ์ธ์‹ํ•  fake image ์ƒ์„ฑ (Gaussian random noise ๋กœ ๋ถ€ํ„ฐ image ์ƒ์„ฑ)
    → Discriminator train : Generator ๋ฅผ constant ํ•˜๊ฒŒ freeze

  3. Requirements
    • The learning rates need tuning
      • If the Discriminator is too strong
        ⇒ it always outputs values close to 0 or 1
        ⇒ the Generator gets no useful gradient
      • If the Generator is too smart
        ⇒ it keeps exploiting the discriminator's weaknesses
        ⇒ pushing the discriminator to predict false negatives
    • A GPU is needed
      • because GAN training takes a long time

Utilities

# imports assumed by all snippets in this post
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from IPython import display

# function that shows how the generated fake images are shaping up
def plot_multiple_images(images, n_cols=None):
    '''visualizes fake images'''
    display.clear_output(wait=False)  
    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1

    if images.shape[-1] == 1:
        images = np.squeeze(images, axis=-1)

    plt.figure(figsize=(n_cols, n_rows))
    
    for index, image in enumerate(images):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(image, cmap="binary")
        plt.axis("off")

๋ฐ์ดํ„ฐ ๋‹ค์šด ๋ฐ ์ค€๋น„

# load the train set of the MNIST dataset
(X_train, _), _ = keras.datasets.mnist.load_data()

# normalize pixel values
X_train = X_train.astype(np.float32) / 255

# create batches of training images to feed the model during training
BATCH_SIZE = 128

dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True).prefetch(1)
# drop_remainder : drops the leftover partial batch (60000 / 128 leaves a remainder)
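A quick arithmetic check of the drop_remainder comment above: 60000 is not a multiple of 128, so the final partial batch is discarded.

```python
# 60000 MNIST training images split into batches of 128
n_images, batch_size = 60000, 128

full_batches = n_images // batch_size  # batches that survive drop_remainder=True
dropped = n_images % batch_size        # images in the discarded partial batch

print(full_batches, dropped)  # 468 96
```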

Generator

: ์ž„์˜์˜ ๋…ธ์ด์ฆˆ๋ฅผ ๋ฐ›์•„ ๊ฐ€์งœ ์ด๋ฏธ์ง€๋ฅผ ๋งŒ๋“œ๋Š” ๋ฐ ์‚ฌ์šฉ

⇒ ๋žœ๋ค ๋…ธ์ด์ฆˆ ํ˜•ํƒœ๋กœ ๋ฐ›์•„์„œ MNIST ๋ฐ์ดํ„ฐ ์„ธํŠธ (์˜ˆ : 28 x 28)์™€ ๋™์ผํ•œ ํฌ๊ธฐ์˜ ์ด๋ฏธ์ง€๋ฅผ ์ถœ๋ ฅ

# declare shape of the noise input
# use 32 dimensions for the normal-distribution noise
random_normal_dimensions = 32

# generator model
# SELU : a ReLU-family activation well suited to GANs => used in the first two Dense layers
# sigmoid : so the final Dense layer produces pixel values between 0 and 1
generator = keras.models.Sequential([                                 
    keras.layers.Dense(64, activation="selu", input_shape=[random_normal_dimensions]),
    keras.layers.Dense(128, activation="selu"),
    keras.layers.Dense(28 * 28, activation="sigmoid"),
    keras.layers.Reshape([28, 28])  # reshape to match the dataset's dimensions
])

# ํ›ˆ๋ จ๋˜์ง€ ์•Š์€ generator์˜ ์ƒ˜ํ”Œ ์ถœ๋ ฅ => random point ๊ทธ ์ž์ฒด๋ฅผ ์ถœ๋ ฅํ•  ๊ฒƒ
# ํ›ˆ๋ จ ํ›„ : MNIST ๋ฐ์ดํ„ฐ ์„ธํŠธ์˜ ์ˆซ์ž๋ฅผ ๋‹ฎ์„๊ฒƒ 
# batch size = 16๋กœ ๋…ธ์ด์ฆˆ ์ƒ์„ฑ
test_noise = tf.random.normal([16, random_normal_dimensions])

# feed the batch to the untrained generator
test_image = generator(test_noise)

# visualize sample output
plot_multiple_images(test_image, n_cols=4)

Discriminator

: takes an input image (fake or real) and decides whether it is fake

# input shape = the shape of the training images
# => flatten => feed through dense layers => output : a value between 0 (fake) and 1 (real)
# build the discriminator model
discriminator = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    keras.layers.Dense(128, activation="selu"), # SELU activation in the first two dense layers
    keras.layers.Dense(64, activation="selu"),
    keras.layers.Dense(1, activation="sigmoid") # sigmoid activation on the final layer
])

Building the GAN and preparing for training

# build the GAN by stacking the Generator and the Discriminator
gan = keras.models.Sequential([generator, discriminator])

# Configure Training Parameters
# binary_crossentropy loss => because the labels will be 0 (fake) or 1 (real)
discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
# trainable is cached at compile time, so the discriminator still trains when
# called on its own (already compiled above) but is frozen inside gan
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")

๋ชจ๋ธ ํ›ˆ๋ จ

  • 1 ๋‹จ๊ณ„ - ๊ฐ€์งœ ๋ฐ์ดํ„ฐ์™€ ์‹ค์ œ ๋ฐ์ดํ„ฐ๋ฅผ ๊ตฌ๋ถ„ํ•˜๋„๋ก ํŒ๋ณ„์ž๋ฅผ ํ›ˆ๋ จ์‹œํ‚ต๋‹ˆ๋‹ค.
  • 2 ๋‹จ๊ณ„ - ํŒ๋ณ„์ž๋ฅผ ์†์ด๋Š” ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋„๋ก ์ƒ์„ฑ์ž๋ฅผ ํ›ˆ๋ จํ•ฉ๋‹ˆ๋‹ค.

๊ฐ epoch๋งˆ๋‹ค ์ƒ์„ฑ์ž์— ์˜ํ•ด ์ƒ์„ฑ๋˜๋Š” ๊ฐ€์งœ ์ด๋ฏธ์ง€๋ฅผ ํ™•์ธํ•˜๊ธฐ ์œ„ํ•ด ์ด๋ฏธ์ง€ ์ƒ˜ํ”Œ ๊ฐค๋Ÿฌ๋ฆฌ๋ฅผ ํ‘œ์‹œ

def train_gan(gan, dataset, random_normal_dimensions, n_epochs=100):
    """ Defines the two-phase training loop of the GAN
    Args:
      gan -- the GAN model which has the generator and discriminator
      dataset -- the training set of real images
      random_normal_dimensions -- dimensionality of the input to the generator
      n_epochs -- number of epochs
    """

    # get the two sub networks from the GAN model
    generator, discriminator = gan.layers

    # start loop
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))       
        for real_images in dataset: # read batches of real images from the dataset
            # infer the batch size from the training batch
            batch_size = real_images.shape[0]

            # Train the discriminator - PHASE 1
            # generate noise
            noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])
            
            # ๋…ธ์ด์ฆˆ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ€์งœ ์ด๋ฏธ์ง€ ์ƒ์„ฑ
            fake_images = generator(noise)
            
            # ๊ฐ€์งœ ์ด๋ฏธ์ง€์™€ ์‹ค์ œ ์ด๋ฏธ์ง€๋ฅผ ์—ฐ๊ฒฐํ•˜์—ฌ list ๋งŒ๋“ค๊ธฐ
            mixed_images = tf.concat([fake_images, real_images], axis=0)
            
            # create the labels for the discriminator (this makes it supervised learning)
            # 0 for the fake images
            # 1 for the real images
            discriminator_labels = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            
            # Ensure that the discriminator is trainable
            discriminator.trainable = True
            
            # train the discriminator on mixed_images and discriminator_labels with train_on_batch
            # correct predictions (0 as 0, 1 as 1) give no loss;
            # mistakes are corrected through backpropagation
            discriminator.train_on_batch(mixed_images, discriminator_labels)
            # (train_on_batch runs a single gradient step on one batch,
            #  unlike fit, which loops over whole epochs)
            
            # Train the generator - PHASE 2
            # generate a fresh batch of noise to feed the GAN
            noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])

            # ์ƒ์„ฑ๋œ ๋ชจ๋“  ์ด๋ฏธ์ง€์— "real" ๋ ˆ์ด๋ธ” ์ง€์ •
            generator_labels = tf.constant([[1.]] * batch_size)

            # Freeze the discriminator : hold the trained discriminator fixed
            discriminator.trainable = False

            # ๋ ˆ์ด๋ธ”์ด ๋ชจ๋‘ true๋กœ ์„ค์ •๋œ ๋…ธ์ด์ฆˆ์— ๋Œ€ํ•œ GAN ํ›ˆ๋ จ
            # generator์˜ ์ถœ๋ ฅ์ด discriminator์˜ ์ž…๋ ฅ์œผ๋กœ ๋“ค์–ด๊ฐ€, 1, 0 ์ค‘ ํ•˜๋‚˜๋กœ ํŒ๋‹จ
            # discriminator๊ฐ€ ์†์•„์„œ 1์„ ์ถœ๋ ฅ์‹œ, generator_labels๊ณผ ๋™์ผํ•˜๋ฏ€๋กœ ์†์‹ค์ด 0 => ์„ฑ๊ณต!
            gan.train_on_batch(noise, generator_labels) # generator_labels = 1
            # gan.fit(noise, generator_labels)

        #ํŒ๋ณ„์ž๋ฅผ ํ›ˆ๋ จํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋Š” ๊ฐ€์งœ ์ด๋ฏธ์ง€๋ฅผ ํ”Œ๋กœํŒ…ํ•ฉ๋‹ˆ๋‹ค.
        plot_multiple_images(fake_images, 8)                     
        plt.show() 
        
# run training
train_gan(gan, dataset, random_normal_dimensions, n_epochs=30)  

Mode Collapse : as training progresses, the model tends to collapse onto a subset of digits such as 1, 7, and 9