[AI ์คํฐ๋] Section 11 : GAN (Generative Adversarial Network)
GAN (์ ๋์ ์์ฑ ๋ชจ๋ธ) ๋ฐ์
- ์ต์ด ์ ์ : (2014) Ian Goodfellow
- ์ฌ์ฉ : Computer ๊ฐ ์ด๋ฏธ์ง, ์ธ๊ฐ์ ๋ชฉ์๋ฆฌ, ์
๊ธฐ์๋ฆฌ ๋ฑ์ ์ค์ ์ ๊ฐ์ด ์์ฑ
→ https://www.thispersondoesnotexist.com/ : GAN ์ฐ์ถ๋ฌผ์ ํ์ง ๋ณํ๋ฅผ ๋ณผ ์ ์๋ ๊ฐ์ ์ธ๋ฌผ - ๊ทน์ฐฌ์ ๋ฐ์๋ ์ด์
: ๋ ๊ฐ์ ๋ฅ ๋ด๋ด ๋คํธ์ํฌ๋ก ์ด๋ฃจ์ด์ง : ์์ฑ์ / ํ๋ณ์ ⇒ ๋ ๋คํธ์ํฌ๊ฐ ์๋ก ์ ๋์ ์ผ๋ก ์์ฑํ๋ ๋ชจ๋ธ
Probability Basics
- ์ด์ฐํ๋ฅ : ๊ฐ๊ฐ์ด ๋ฑ๋ฑ ๋จ์ด์ง ๊ฐ์ ๊ฐ์ง๋ ๊ฒฝ์ฐ ⇒ ex) ์ฃผ์ฌ์
- ์ฐ์๋ถํฌํ๋ฅ : ์ฐ์๋ ์ซ์์ ๊ฐ์ ๊ฐ์ง๋ ๊ฒฝ์ฐ ⇒ ex) ์ด๋ฏธ์ง : 64x64x3์ ํ๋ฅ ๋ถํฌ
- ์ด๋ฏธ์ง ํน์ฑ์ ๋ฐ๋ฅธ ํ๋ฅ ๋ถํฌ
: ์ด๋ฏธ์ง ํน์ฑ์ ๋ฐ๋ผ์ ํฝ์ ๊ฐ๋ค์ ํ๋ฅ ๋ถํฌ๊ฐ ๋ค ๋ค๋ฆ
⇒ ํฝ์ ํ๋ํ๋๊ฐ ์ด๋ค ์ด๋ฏธ์ง์ ํน์ฑ๊ณผ ์ฐ๊ด๋จ
⇒ ex) ํฝ์ ๋ง๋ค ํผ๋ถ, ์๊ฒฝ ๋ฑ ๋ค ๋ค๋ฅด๊ฒ ๊ตฌ์ฑ๋์ด, ์ฐ๋ฆฌ ๋์ ๋ค๋ฅธ ๋ชจ์์ ์ฌ์ง์ผ๋ก ๋ํ๋จ - ์ด๋ฏธ์ง ํน์ฑ์ ๋ฐ๋ฅธ ํ๋ฅ ๋ค๋ณ์ ํ๋ฅ ๋ถํฌ
: ๊ฐ์ฅ ๋์ ํ๋ฅ ์ ๋ถ๋ถ์ ์ค์ ์ด๋ฏธ์ง์ ๊ฐ์ฅ ๋น์ทํ ์ด๋ฏธ์ง๊ฐ ์์ฑ๋ ๊ฒ

์์ฑ ๋ชจ๋ธ
: ์ค์ ์กด์ฌํ์ง๋ ์์ง๋ง ์์ ๋ฒํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ ์๋ ๋ชจ๋ธ
- ๋ถ๋ฅ๋ชจ๋ธ๊ณผ์ ์ฐจ์ด
- ๋ถ๋ฅ ๋ชจ๋ธ : ๊ฒฐ์ ๊ฒฝ๊ณ๋ฅผ ํ์ต
- ์์ฑ ๋ชจ๋ธ : ๊ฐ ํด๋์ค์ ๋ถํฌ๋ฅผ ํ์ต ⇒ ๊ฒฐํฉํ๋ฅ (joint probability) ์ ํ์ต
- ๊ฒฐํฉํ๋ฅ

⇒ ๋ฌด์์ ํ์ตํ๋๋์ ๋์์ ๋ฐ๋ผ ์ด๋ฏธ์ง ํ๋ฆฌํฐ๊ฐ ๋ฌ๋ผ์ง - ํฝ์ ์ ๋ถํฌ ์์ฒด๊ฐ ๋ฎ์ ํ๋ฅ ์ ๋ถํฌ๋ฅผ ํ์ต ์, ์ด์ํ ์ด๋ฏธ์ง ์์ฑ
- ๋์ ํ๋ฅ ์ ๋ถํฌ๋ฅผ ํ์ต ์, ๊ทธ๋ด๋ฏํ ์ด๋ฏธ์ง๋ฅผ ์์ฑ
GAN์ ๋ชฉํ
: ๋๊ฐ ํ๋ฅ ๋ถํฌ์ ์ฐจ์ด๋ฅผ ์ค์ฌ ์ฃผ๋ ๊ฒ

- ํ๋ฅ ๋ถํฌ : ์ค์ ์ด๋ฏธ์ง์ ๊ฐ ํฝ์
์ด ์ด๋ค ๊ฐ์ ๊ฐ์ง๊ณ ์์ ๊ฒ์ธ๊ฐ
⇒ ํน์ง : ์ธ์ ํ ํฝ์ ๋ค ๋ผ๋ฆฌ ์ฐ๊ด์ด ๋์ด์์ (ex : ํฝ์ ์ด ๋ฌ๋ผ์ง๋ค๊ณ ํผ๋ถ์์ด ๋ ธ๋์์ด ์ด๋ก์์ด ๋์ง ์์) - ํ๋ฅ ๋ชจ๋ธ : ๋ชจ๋ธ์ด ์์ฑํ ์ด๋ฏธ์ง์ ํ๋ฅ ๋ถํฌ
⇒ ํ๋ฅ ๋ฐ์ดํฐ X : ์ค์ ๋ฐ์ดํฐ
GAN ํ์ต

๊ฒ์ ์ ์ (a) : ์๋ณธ ๋ฐ์ดํฐ ์ด๋ฏธ์ง์ ๋ถํฌ (\(p_{data}(x)\)) → ๋ฐ์ดํฐ๊ฐ ์ ํํ๋ฏ๋ก ์ ์ผ๋ก ํ์
ํ๋ ์ ์ (b) : discriminator distribution
๋
น์ ์ (c) : generator distribution ( \( p_z(z) \) )
⇒ x(real), z(fake) ์ : ๊ฐ๊ฐ x, z์ ๋๋ฉ์ธ
⇒ ์๋ก ๋ป์ ํ์ดํ : \( x=G(z) \) ์ ๋งคํ
- ( a )
- a : ์ค์ ๋ฐ์ดํฐ๊ฐ ๊ฐ์ฅ ๋ง์ด ๋ชจ์ธ ๊ณณ
- ( b )
- ์ด๊ธฐ์ ์์ฑ์ : ์ค๋ฅธ์ชฝ์ผ๋ก ์ ๋ ค์ ๋น์ ์์ ์ธ ์ฌ๋์ฒ๋ผ ๋ณด์ด๋ ๋ถํฌ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ง๋ฌ
- ์์ฑ์ : ์ผ์ชฝ์ผ๋ก ๋ถํฌ๊ฐ ์ฎ๊ฒจ์ง
- ( c )
- ํ๋ณ์ : ์ง์ง, ๊ฐ์ง๋ฅผ ์ ์์๋ง์ถค
- ์์ฑ์๊ฐ ๋ง๋ค์ด๋ธ ์ด๋ฏธ์ง ๋ถํฌ๊ฐ ์ ์ ์ผํฐ๋ก ์ฎ๊ฒจ๊ฐ
- ( d )
- ์ค์ ๋ฐ์ดํฐ์ ์์ฑ์๊ฐ ๋ง๋ค์ด๋ธ ๋ฐ์ดํฐ์ ์ผ์นํ ์ํฉ ⇒ ์ค์ ๋ฐ์ดํฐ์ ๋ถํฌ๋ฅผ ์๋ฒฝํ ํ์ตํจ
- ์์ฑ์ : ๋ถํฌ๋ฅผ ํ๋ด๋ด, ๋๊ฐ์ด ํฝ์
๊ฐ์ ๋ง๋ฌ
⇒ ํ๋ณ์ : ํ๋ฅ ์ 0.5๋ก ๋ด๋ณด๋ (๋ชจ๋ฅด๊ฒ ๋ค. ๋ผ๋ ์๋ฏธ๋ก !)
Generative vs. Discriminative Algorithms
- Discriminative Algorithm
: ์ธํ ๋ฐ์ดํฐ์ ํผ์ฒ๋ฅผ ๊ธฐ์ค์ผ๋ก ๋ ์ด๋ธ ์์ธก (ex. Spam ๋ถ๋ฅ)
→ ์ผ๋ฐ์ ์ธ ๋ถ๋ฅ ๋ชจ๋ธ์ด๋ค.
→ p(y | X) → “the probability of y given X - Generative Algorithm
: ์ฃผ์ด์ง ๋ ์ด๋ธ์ ๊ธฐ์ค์ผ๋ก feature ์์ธก (feature extraction (x) → feature filling(o))
→ p(X | y) → “the probability of features given y”
GAN process (๊ณผ์ )


- Training of Discriminator
- Generator : ๋์๋ฅผ ๋ฐ์์์ผ random image (fake image) ์์ฑ
- Generator๊ฐ ์์ฑํ ๊ฐ์ง ์ด๋ฏธ์ง(0)์ actual dataset ์ ์ง์ง ์ด๋ฏธ์ง(1)๋ฅผ ๊ฐ ๋ ์ด๋ธ์ ๋ถ์ฌํด, discriminator ์๊ฒ ๊ณต๊ธ
- Discriminator : ์ง์ง image๋ 1, ๊ฐ์ง image๋ 0๋ฅผ ์ถ๋ ฅํ๋๋ก ์ด์ง ๋ถ๋ฅ ํ๋ จ
- Binary Classification Problem ์ผ๋ก discriminator ๊ต์ก
- Training of Generator
( generator์ ๋ชฉ์ : discriminator๋ฅผ ์์ด๊ธฐ)
- Discriminator ์ถ๋ ฅ์ crossentropy ๊ฐ์ 1 ๊ณผ ๋น๊ตํด, ์ฐจ์ด๋ถ์ ์์ค๋ก ์ธ์ํ์ฌ ์ญ์ ํ๋ก ๋ณด์
⇒ ๋ง๋ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ 0์ผ๋ก ์์ธก์, backpropagation์ผ๋ก ์์ค์ ์ค์ด๊ธฐ ์ํด ๊ฐ์ค์น ์กฐ์
⇒ ์์ค์ด ์ ์ ์ค์ด๋ค๋ฉด, ๊ฐ์ง์ ์ง์ง๋ฅผ ๊ตฌ๋ถํ์ง ๋ชปํ๊ณ 1์ ์ถ๋ ฅํ๊ฒ ๋ ๊ฒ
- Discriminator ์ถ๋ ฅ์ crossentropy ๊ฐ์ 1 ๊ณผ ๋น๊ตํด, ์ฐจ์ด๋ถ์ ์์ค๋ก ์ธ์ํ์ฌ ์ญ์ ํ๋ก ๋ณด์
GAN ๋ชฉ์ ํจ์ (object function)
: ๋ ๊ฐ์ ๋คํธ์ํฌ๋ฅผ ๋์์ ํ๋ จ → ํ๋๋ ๊ณ ์ ํ๊ณ ํ๋ จํ๊ฒ ๋จ

x : real data๋ก ๋ถํฐ ํ๋ณธ ์ถ์ถ
z : N(0,1)๋ก ๋ถํฐ ํ๋ณธ์ถ์ถ (๋ ธ์ด์ฆ)
D(x) : ๊ฐ๋ณ์๊ฐ x๊ฐ ์ง์ง๋ผ๊ณ ์์ธกํ ํ๋ฅ
G(z) : z๋ก๋ถํฐ ์์ฑ์๊ฐ ์์ฑํ ๊ฐ์ง ์ด๋ฏธ์ง
D(G(z)) : ๊ฐ๋ณ์๊ฐ ๊ฐ์ง๋ฅผ ์ง์ง๋ผ๊ณ ์์ธกํ ํ๋ฅ
Ex : ๋ชจ๋ real data instances ์ ๋ํ ๊ธฐ๋๊ฐ
Ez : ์์ฑ์๋ก์ random ์
๋ ฅ์ ๊ธฐ๋๊ฐ
- Discriminator (D(x) classifier)

D(x) & 1-D(G(z)) ๊ฐ 1์ด ๋ ์๋ก \(max_D V(D)\) : ๋ ๊ฐ๋ฅผ ๋ค maximize ํด์ผ ํจ
- D(x) : real ์ด๋ฉด 1, fake ๋ฉด 0 ์ return ํ๋๋ก ํ๋ จ
D(x) = 1 ์ ๊ฐ๊น์ธ์๋ก ์ข์ ⇒ ์ข์ธก ํญ ($log(x))$ : ๊ฐ์ด ํด ์๋ก ์ข์ - D(G(z)) : G(z) = real ์ด๋ผ๊ณ ํ๋จ → 1, fake๋ก ํ๋จ → 0 return
- 0์ ๊ฐ๊น์ธ ์๋ก ์ข์ ⇒ ์ฐ์ธก ํญ ($log(1-D(G(z))$) : ๊ฐ์ด ํด ์๋ก ์ข์
- Loss ์ธก์ : cross-entropy → p log(q) ๋ฅผ ์ด์ฉ
- D(x) : real ์ด๋ฉด 1, fake ๋ฉด 0 ์ return ํ๋๋ก ํ๋ จ
- Generator
D(G(z)) : 1์ ๊ฐ๊น์ธ ์๋ก ์ข์ ⇒ ์ฐ์ธก ํญ (\(log(1-D(G(z))\)) : ์์ ๊ฐ ์ผ์๋ก ์ข์
1-D(G(z)) ์ด 0 ๊ฐ ๋์ด์ผ ์ข์ ((D(x) ๋ G ์ ๋ฌด๊ดํ๋ฏ๋ก ๋ฌด์)
Mode Collapse
: Generator๊ฐ Discriminator๋ฅผ ์์ด๊ธฐ ์ฌ์ด ๋ชจ๋๋ง ์์ฑํ๊ฒ ๋๋ ํ์
⇒ ex) ๋น์ทํ ๊ธ์๋ง ์์ฑํด ๋ด๋ ํ์


์ค์ต : GAN ๋ชจ๋ธ ์์ฑ ( mnist dataset ์์กฐ )
๋ชจ๋ธ ๊ตฌ์ฑ ๋ฐ ์ฃผ์ ์ฌํญ
- Discriminator ์ Goal
: mnist dataset ์ “์ง์ง” ๋ก ์ธ์ํ๊ณ , Generator ์์ ๊ณต๊ธ๋๋ image ๋ฅผ fake ๋ก ๊ตฌ๋ถ
→ Generator train : Discriminator ๋ฅผ constant ๋ก freeze ํ์ฌ gradient๋ฅผ ์์ ์ ์ผ๋ก ๊ณ์ฐํ ์ ์๋๋ก ํ๋ค - Generator ์ Goal
: discriminator ๊ฐ “์ง์ง” ๋ก ์ธ์ํ fake image ์์ฑ (Gaussian random noise ๋ก ๋ถํฐ image ์์ฑ)
→ Discriminator train : Generator ๋ฅผ constant ํ๊ฒ freeze - ํ์ ์์
- Learning Rate ์กฐ์ ํ์
- Discriminator = ๋๋ฌด ๊ฐํจ
⇒ ํญ์ 0 ๊ณผ 1 ์ ๊ทผ์ฌํ ๊ฐ ๋์ด
⇒ Generator ๊ฐ gradient ๋ชป ์ป์ - Generator = ๋๋ฌด smartํจ
⇒ discriminator ์ weakness ๋ฅผ ๊ณ์ ์ด์ฉ
⇒ discriminative๊ฐ false negative ๋ฅผ predict ํ๋๋ก ํจ
- Discriminator = ๋๋ฌด ๊ฐํจ
- GPU ํ์
- GAN ์ training ์ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฌ๊ธฐ ๋๋ฌธ
- Learning Rate ์กฐ์ ํ์
Utilities
# ๊ฐ์ง ์ด๋ฏธ์ง๊ฐ ์ด๋ป๊ฒ ํ์ฑ๋์๋์ง ์ถ๋ ฅํ๋ ํจ์
def plot_multiple_images(images, n_cols=None):
'''visualizes fake images'''
display.clear_output(wait=False)
n_cols = n_cols or len(images)
n_rows = (len(images) - 1) // n_cols + 1
if images.shape[-1] == 1:
images = np.squeeze(images, axis=-1)
plt.figure(figsize=(n_cols, n_rows))
for index, image in enumerate(images):
plt.subplot(n_rows, n_cols, index + 1)
plt.imshow(image, cmap="binary")
plt.axis("off")
๋ฐ์ดํฐ ๋ค์ด ๋ฐ ์ค๋น
# load the train set of the MNIST dataset
(X_train, _), _ = keras.datasets.mnist.load_data()
# normalize pixel values
X_train = X_train.astype(np.float32) / 255
# ํ๋ จํ๋ ๋์ ๋ชจ๋ธ์ ๊ณต๊ธํ ์ ์๋๋ก ํ๋ จ ์ด๋ฏธ์ง์ ๋ฐฐ์น๋ฅผ ์์ฑ.
BATCH_SIZE = 128
dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True).prefetch(1)
# drop_remainder : ๋ฑ ๋จ์ด์ง์ง ์๋ ๋๋จธ์ง๋ฅผ drop (60000 / 128 = ๋๋จธ์ง๊ฐ ์กด์ฌํ๊ธฐ ๋๋ฌธ)
Generator
: ์์์ ๋ ธ์ด์ฆ๋ฅผ ๋ฐ์ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ๋ง๋๋ ๋ฐ ์ฌ์ฉ
⇒ ๋๋ค ๋ ธ์ด์ฆ ํํ๋ก ๋ฐ์์ MNIST ๋ฐ์ดํฐ ์ธํธ (์ : 28 x 28)์ ๋์ผํ ํฌ๊ธฐ์ ์ด๋ฏธ์ง๋ฅผ ์ถ๋ ฅ
# declare shape of the noise input
# normal distribution์ ๋ง๋ค dimension์ 32์ฐจ์์ผ๋ก
random_normal_dimensions = 32
# generator model
# SELU : GAN์ ์ ํฉํ ํ์ฑํ ํจ์๋ก ReLU ๊ณ์ด์ => ์ฒ์ ๋ ๊ฐ์ Dense ๋คํธ์ํฌ์์ ์ฌ์ฉ
# sigmoid : ์ต์ข
Dense ๋คํธ์ํฌ๋ 0๊ณผ 1 ์ฌ์ด์ ํฝ์
๊ฐ์ ์์ฑํ๊ธฐ ์ํด์
generator = keras.models.Sequential([
keras.layers.Dense(64, activation="selu", input_shape=[random_normal_dimensions]),
keras.layers.Dense(128, activation="selu"),
keras.layers.Dense(28 * 28, activation="sigmoid"),
keras.layers.Reshape([28, 28]) # ๋ฐ์ดํฐ ์ธํธ์ ์ฐจ์์ ๋ง๊ฒ reshape
])
# ํ๋ จ๋์ง ์์ generator์ ์ํ ์ถ๋ ฅ => random point ๊ทธ ์์ฒด๋ฅผ ์ถ๋ ฅํ ๊ฒ
# ํ๋ จ ํ : MNIST ๋ฐ์ดํฐ ์ธํธ์ ์ซ์๋ฅผ ๋ฎ์๊ฒ
# batch size = 16๋ก ๋
ธ์ด์ฆ ์์ฑ
test_noise = tf.random.normal([16, random_normal_dimensions])
# feed the batch to the untrained generator
test_image = generator(test_noise)
# visualize sample output
plot_multiple_images(test_image, n_cols=4)
Discriminator
: ์ ๋ ฅ(๊ฐ์ง ๋๋ ์ค์ ) ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ ธ์ ๊ฐ์ง์ธ์ง ์๋์ง๋ฅผ ๊ฒฐ์
# input shape = ํ๋ จ ์ด๋ฏธ์ง์ ๋ชจ์
# => flatten => dense network์ ๊ณต๊ธ => ์ถ๋ ฅ : 0 (๊ฐ์ง)๊ณผ 1 (์ค์ ) ์ฌ์ด์ ๊ฐ
# build the discriminator model
discriminator = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(128, activation="selu"), # ์ฒ์ ๋ ๊ฐ์ dense network์์ SELU ํ์ฑํ
keras.layers.Dense(64, activation="selu"),
keras.layers.Dense(1, activation="sigmoid") # sigmoid๋ก final network๋ฅผ ํ์ฑํ
])
GAN ๊ตฌ์ถ ๋ฐ ํ๋ จ์ ์ํ ์ค๋น
# Generator, Discriminator ๋ ๋ชจ๋ธ์ ์ถ๊ฐํด GAN ๊ตฌ์ถ
gan = keras.models.Sequential([generator, discriminator])
# Configure Training Parameters
# binary_crossentropy๋ก ์์ค => ์ด์ : ๋ผ๋ฒจ์ด 0 (๊ฐ์ง) ๋๋ 1 (์ค์ )์ด ๋ ๊ฒ
discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")
๋ชจ๋ธ ํ๋ จ
- 1 ๋จ๊ณ - ๊ฐ์ง ๋ฐ์ดํฐ์ ์ค์ ๋ฐ์ดํฐ๋ฅผ ๊ตฌ๋ถํ๋๋ก ํ๋ณ์๋ฅผ ํ๋ จ์ํต๋๋ค.
- 2 ๋จ๊ณ - ํ๋ณ์๋ฅผ ์์ด๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋๋ก ์์ฑ์๋ฅผ ํ๋ จํฉ๋๋ค.
๊ฐ epoch๋ง๋ค ์์ฑ์์ ์ํด ์์ฑ๋๋ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ํ์ธํ๊ธฐ ์ํด ์ด๋ฏธ์ง ์ํ ๊ฐค๋ฌ๋ฆฌ๋ฅผ ํ์
def train_gan(gan, dataset, random_normal_dimensions, n_epochs=100):
""" Defines the two-phase training loop of the GAN
Args:
gan -- the GAN model which has the generator and discriminator
dataset -- the training set of real images
random_normal_dimensions -- dimensionality of the input to the generator
n_epochs -- number of epochs
"""
# get the two sub networks from the GAN model
generator, discriminator = gan.layers
# start loop
for epoch in range(n_epochs):
print("Epoch {}/{}".format(epoch + 1, n_epochs))
for real_images in dataset: # ๋ฐ์ดํฐ์
์์ ์ค์ ์ด๋ฏธ์ง๋ฅผ ์ฝ์ด๋ค์
# ํ๋ จ ๋ฐฐ์น์์ ๋ฐฐ์น ํฌ๊ธฐ ์ถ๋ก
batch_size = real_images.shape[0]
# Train the discriminator - PHASE 1
# noise ์์ฑ
noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])
# ๋
ธ์ด์ฆ๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ง ์ด๋ฏธ์ง ์์ฑ
fake_images = generator(noise)
# ๊ฐ์ง ์ด๋ฏธ์ง์ ์ค์ ์ด๋ฏธ์ง๋ฅผ ์ฐ๊ฒฐํ์ฌ list ๋ง๋ค๊ธฐ
mixed_images = tf.concat([fake_images, real_images], axis=0)
# discriminator๋ฅผ ์ํ label ์์ฑ (์ง๋ํ์ต์ ํ๊ธฐ ์ํจ)
# 0 for the fake images
# 1 for the real images
discriminator_labels = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
# Ensure that the discriminator is trainable
discriminator.trainable = True
# train_on_batch๋ฅผ ์ฌ์ฉํ์ฌ mixed_images ์ discriminator_labels๋ก ํ๋ณ์๋ฅผ ํ๋ จํฉ๋๋ค.
# 0์ 0์ผ๋ก ๋ง์ถ๊ณ , 1์ 1๋ก ๋ง์ถค => ์์ค์ด ์์
# ์๋ชป ๋ง์ถ๋ฉด backpropagation ๋ ๊ฒ
discriminator.train_on_batch(mixed_images, discriminator_labels)
# discriminator.fit(mixed_images, discriminator_labels) : fit == train_on_batch
# Train the generator - PHASE 2
# GAN์ ๊ณต๊ธํ ๋
ธ์ด์ฆ ์
๋ ฅ ๋ฐฐ์น ์๋กญ๊ฒ ์์ฑ
noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])
# ์์ฑ๋ ๋ชจ๋ ์ด๋ฏธ์ง์ "real" ๋ ์ด๋ธ ์ง์
generator_labels = tf.constant([[1.]] * batch_size)
# Freeze the discriminator : ํ๋ จ๋ ๊ฐ๋ณ์ ๊ณ ์
discriminator.trainable = False
# ๋ ์ด๋ธ์ด ๋ชจ๋ true๋ก ์ค์ ๋ ๋
ธ์ด์ฆ์ ๋ํ GAN ํ๋ จ
# generator์ ์ถ๋ ฅ์ด discriminator์ ์
๋ ฅ์ผ๋ก ๋ค์ด๊ฐ, 1, 0 ์ค ํ๋๋ก ํ๋จ
# discriminator๊ฐ ์์์ 1์ ์ถ๋ ฅ์, generator_labels๊ณผ ๋์ผํ๋ฏ๋ก ์์ค์ด 0 => ์ฑ๊ณต!
gan.train_on_batch(noise, generator_labels) # generator_labels = 1
# gan.fit(noise, generator_labels)
#ํ๋ณ์๋ฅผ ํ๋ จํ๋ ๋ฐ ์ฌ์ฉ๋๋ ๊ฐ์ง ์ด๋ฏธ์ง๋ฅผ ํ๋กํ
ํฉ๋๋ค.
plot_multiple_images(fake_images, 8)
plt.show()
# ํ๋ จ
train_gan(gan, dataset, random_normal_dimensions, n_epochs=30)
