
[AI Study] Section 8 : Transfer Learning

egahyun 2024. 12. 27. 02:11

์ „์ดํ•™์Šต์ด๋ž€?

๋ชจ๋ธ ์ž‘๋™ ํ๋ฆ„

  1. ๊ธฐ์กด์˜ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ : 1000๊ฐ€์ง€๋ฅผ ๋ถ„๋ฅ˜ํ•  ์ˆ˜ ์žˆ๋Š” ํฐ ๋ชจ๋ธ
  2. weight transfer
    1. ์‚ฌ์šฉ ์˜ˆ์‹œ : ๊ธฐ์กด ํฐ ๋ชจ๋ธ์ด ์žˆ๊ณ , 2๊ฐ€์ง€๋ฅผ ๋ถ„๋ฅ˜ํ•  ์ˆ˜ ์žˆ๋Š” ๋ชจ๋ธ์„ ํ•˜๊ณ ์‹ถ์„ ๋•Œ
    2. ํšจ๊ณผ : ์†Œ๊ทœ๋ชจ์˜ ๋ฐ์ดํ„ฐ ์…‹์œผ๋กœ ํ›ˆ๋ จ์—๋„ ์ข‹์€ ์„ฑ๋Šฅ์„ ๋‚ผ ์ˆ˜ ์žˆ์Œ
    3. ํ›ˆ๋ จ๋œ ์ปจ๋ณผ๋ฃจ์…˜ ๋ ˆ์ด์–ด : ํ›ˆ๋ จ๋œ ํ•„ํ„ฐ๋“ค์ด ๋‚ด์žฅ๋˜์–ด์žˆ์Œ
    4. ⇒ ์ž‘์€ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ๋„ ๋ถ„๋ฅ˜ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•จ
  3. : ๊ธฐ์กด ๋ชจ๋ธ์˜ ์ปจ๋ณผ๋ฃจ์…˜ ๋ ˆ์ด์–ด๋ฅผ ๊ฐ€์ ธ์™€ ๊ฐ€์ค‘์น˜๋ฅผ ์ „์œ„์‹œ์ผœ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ
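The flow above can be sketched in plain numpy: copy the convolutional weights of a (toy) 1,000-class model and attach a freshly initialized 2-class head. The dictionary layout and names here are illustrative, not a real framework API.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "pre-trained" model: learned conv filters + a 1000-class head
pretrained = {
    "conv_filters": rng.normal(size=(3, 3, 3, 32)),  # learned filters (kept)
    "head": rng.normal(size=(32, 1000)),             # 1000-class classifier (discarded)
}

# Weight transfer: reuse the conv filters, re-initialize a 2-class head
new_model = {
    "conv_filters": pretrained["conv_filters"].copy(),  # transferred weights
    "head": rng.normal(size=(32, 2)),                   # new, untrained head
}

print(new_model["head"].shape)  # (32, 2)
```

Only the small new head needs to be learned from the small dataset; the transferred filters already know how to extract features.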

Training Strategies

  1. CNN layers = frozen / newly added fully connected layers (Dense layers) = trained from scratch
    1. Use case : when the dataset is very small
  2. All layers = retrained with a very small learning rate
    1. Use case : when there is enough data to retrain the whole network
    2. A small learning rate must be used ⇒ keep the pre-trained weights largely intact and tune them only slightly! (a large learning rate would wreck the weights)
    3. (because already-trained convolutional layers are being retrained)
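Why the learning rate matters can be seen with a one-step toy update (illustrative numbers, not the course's code): a small rate nudges a pre-trained weight, while a large rate destroys it.

```python
# A well-trained weight and the gradient from the new, small dataset (made-up values)
w = 0.80
grad = 2.5

w_small = w - 1e-4 * grad  # tiny learning rate: the weight barely moves
w_large = w - 1.0 * grad   # large learning rate: the weight is wrecked

print(w_small)  # 0.79975  (still close to the pre-trained value)
print(w_large)  # -1.7     (pre-trained knowledge destroyed)
```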

Considerations

  1. ๋ชฉ์ ์— ๋งž๋Š” ๋ฐ์ดํ„ฐ ์…‹ ์„ ํƒ
    1. ์žˆ๋Š” ๋ฐ์ดํ„ฐ์…‹ : ๋ถ„๋ฅ˜๊ฐ€ ์ž˜๋  ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์Œ (ex : Cat & Dog ๊ตฌ๋ถ„ → ImageNet ์— ํฌํ•จ)
    2. ์—†๋Š” ๋ฐ์ดํ„ฐ์…‹ : ๋ถ„๋ฅ˜๊ฐ€ ์•ˆ๋จ (ex : Cancer cell ๊ตฌ๋ถ„ → ImageNet ์— ์—†์Œ ⇒ ์•ผ๊ตฌ๊ณต ๋“ฑ์œผ๋กœ ๋ถ„๋ฅ˜)
  2. : ImageNet์— ์žˆ๋Š” ๋ฐ์ดํ„ฐ ์…‹์ธ์ง€ ํ™•์ธ ํ•„์š”
  3. ๋ณด์œ  ๋ฐ์ดํ„ฐ์˜ Volume ๊ณ ๋ ค
    1. ๊ตฌ์กฐ๋งŒ ๊ฐ€์ ธ์˜ค๊ณ , ๋ชจ๋“  weight ์ƒˆ๋กœ์ด training (Large Data ๋ณด์œ )
    2. Weight ์˜ ์ผ๋ถ€๋งŒ training
    3. ๋งˆ์ง€๋ง‰ layer ๋งŒ Fine-tuning (Small Data ๋ณด์œ ) : ํ›ˆ๋ จ์ž˜ ๋˜์–ด ์žˆ๋˜ ๋ถ€๋ถ„์€ ๊ฑด๋“ค์ง€ ์•Š์Œ
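The data-volume rules above can be summarized as a small helper function; the sample-count thresholds are made up purely for illustration.

```python
def choose_strategy(num_samples: int) -> str:
    """Pick a transfer-learning strategy from dataset size (illustrative thresholds)."""
    if num_samples >= 100_000:
        return "train all weights from scratch (architecture only)"
    if num_samples >= 10_000:
        return "train part of the weights"
    return "fine-tune only the last layer"

print(choose_strategy(500))      # fine-tune only the last layer
print(choose_strategy(500_000))  # train all weights from scratch (architecture only)
```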

TensorFlow Hub

https://www.tensorflow.org/hub?hl=ko : as long as you can use Python, you can build a model from its pre-trained components

ImageDataGenerator

: a utility that yields images together with their labels

  1. method : .flow_from_directory
    • loads large datasets directly from a directory
    • infers labels automatically from the directory structure

  2. How to use flow_from_directory

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Create an instance
train_data_gen = ImageDataGenerator(rescale=1/255.)

# Call the flow_from_directory method :
train_generator = train_data_gen.flow_from_directory(
	train_dir,
	target_size=(150, 150),  # unify image size
	batch_size=20,
	class_mode='binary')     # or 'categorical' for multi-class

 

3. Data augmentation

: compensates for scarce data by transforming images so they look like different images

 → ex) distorting an image / flipping it horizontally

 → configured on the ImageDataGenerator instance and applied as batches flow from flow_from_directory
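At its core an augmentation such as a horizontal flip is just an array transform; a minimal numpy sketch of what the generator does internally when horizontal flipping is enabled:

```python
import numpy as np

img = np.arange(12).reshape(2, 2, 3)  # tiny 2x2 RGB "image"
flipped = img[:, ::-1, :]             # flip left-right along the width axis

# The top-left pixel of the original becomes the top-right pixel of the flip
assert np.array_equal(flipped[0, 1], img[0, 0])
# Flipping twice recovers the original image
assert np.array_equal(flipped[:, ::-1, :], img)
```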

 


Hands-on - Transfer Learning with a TensorFlow Hub Model

 

๋ฌธ์ œ ์ •์˜

: pre-trained model (MobileNet_V2) ์„ feature extractor๋กœ ์ด์šฉํ•˜์—ฌ ๊ฝƒ image ์— ํŠนํ™”๋œ image ๋ถ„๋ฅ˜ model

๋ชจ๋ธ ๊ตฌ์„ฑ : MobileNet

  • ImageNet ์˜ ์ˆ˜๋ฐฑ๋งŒ์žฅ ์ด๋ฏธ์ง€ ๋ฅผ ์ด์šฉํ•˜์—ฌ ํ›ˆ๋ จ๋จ
  • 1000 ๊ฐœ์˜ class ๋กœ ์ด๋ฏธ์ง€ ๊ตฌ๋ถ„
  • class์ค‘ ํ™•๋ฅ ์ด ๋†’์€ 5๊ฐœ๋ฅผ ๋ฐ˜ํ™˜
model = tf.keras.Sequntial([
	**MobileNet_feature_extractor_layer**
	tf.keras.layers.Dense(flowers_data.num_classes, activation='softmax')
])    

๋ชจ๋ธ ๊ตฌ์„ฑ : ํŒŒ์ธํŠœ๋‹ํ•˜์ง€ ์•Š์€ ๋ชจ๋ธ

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.applications.mobilenet import decode_predictions # maps class indices to human-readable labels

# Download the full model, to use it without fine-tuning
Trained_MobileNet_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2"

Trained_Mobilenet = tf.keras.Sequential([
                    hub.KerasLayer(Trained_MobileNet_url, input_shape=(224, 224, 3))]) # must use the same input size as in pre-training

Trained_Mobilenet.input, Trained_Mobilenet.output
# (input: (None, 224, 224, 3),
#  output: (None, 1001)) -> produces a probability distribution over 1001 classes
# ์œ„์˜ ์ฝ”๋“œ๊ฐ€ ์˜ค๋ฅ˜๋‚˜๋Š” ๊ฒฝ์šฐ ํ•ด๊ฒฐ ๋ฐฉ๋ฒ• 1
!pip install tf_keras

import tf_keras as tfk

Trained_MobileNet_url = "<https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2>"

Trained_MobileNet = tfk.Sequential([
    hub.KerasLayer(Trained_MobileNet_url, input_shape=(224, 224, 3))
])

Trained_MobileNet.input, Trained_MobileNet.output

# ํ•ด๊ฒฐ ๋ฐฉ๋ฒ• 2 : ์•„๋ž˜์ฒ˜๋Ÿผ ๋‹ค์šด๊ทธ๋ ˆ์ด๋“œํ›„, ์›๋ž˜ ์ฝ”๋“œ ์‹คํ–‰
!pip install tensorflow==2.13.0 !pip install tensorflow_hub==0.14.0

ํŒŒ์ธ ํŠœ๋‹ํ•˜์ง€ ์•Š์€ ๋ชจ๋ธ๋กœ ๋ถ„๋ฅ˜

import numpy as np               # array handling
import matplotlib.pyplot as plt  # visualization
from PIL import Image            # image handling
from urllib import request      # fetching from the internet
from io import BytesIO          # raw bytes fetched from the internet must be wrapped to be opened as an image

# Fetch the image to classify from a URL
url = "https://github.com/ironmanciti/MachineLearningBasic/blob/master/datasets/TransferLearningData/watch.jpg?raw=true"
res = request.urlopen(url).read()
Sample_Image = Image.open(BytesIO(res)).resize((224, 224)) # resize (as a tuple) to match the model's input size

# Convert the sample image to a numpy array and preprocess it
# We use preprocess_input because the model expects inputs preprocessed the same way as during training
x = tf.keras.applications.mobilenet.preprocess_input(np.array(Sample_Image))
x.shape # (224, 224, 3)
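For MobileNet, preprocess_input scales pixel values from [0, 255] to [-1, 1]; a numpy sketch of that scaling and its inverse (the `(x + 1) * 127.5` used later when plotting):

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])

preprocessed = pixels / 127.5 - 1.0     # [0, 255] -> [-1, 1], as preprocess_input does for MobileNet
restored = (preprocessed + 1) * 127.5   # [-1, 1] -> [0, 255], undoing the preprocessing

print(preprocessed)  # [-1.  0.  1.]
```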

# Class prediction : a distribution over 1001 probabilities (1000 classes + 1 background)
predicted_class = Trained_Mobilenet.predict(np.expand_dims(x, axis=0)) # a single image, shaped like a batch

# Extract the index of the predicted class -> ex) 827 : classified as class 827
predicted_class.argmax(axis=-1)

# Check the top-5 class probabilities for this image
decode_predictions(predicted_class[:, 1:])  # the first label is the background

[[('n04328186', 'stopwatch', 9.666367),
  ('n02708093', 'analog_clock', 8.007808),
  ('n03706229', 'magnetic_compass', 6.8384614),
  ('n04548280', 'wall_clock', 6.563991),
  ('n03197337', 'digital_watch', 4.9182053)]]
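The top-5 selection that decode_predictions performs boils down to sorting scores; a numpy sketch with made-up scores and labels:

```python
import numpy as np

scores = np.array([0.05, 0.60, 0.10, 0.02, 0.20, 0.03])  # made-up class scores
labels = np.array(["cat", "stopwatch", "clock", "gong", "compass", "reef"])

top5 = np.argsort(scores)[::-1][:5]  # indices of the 5 highest scores, descending
print(labels[top5])                  # ['stopwatch' 'compass' 'clock' 'cat' 'reef']
```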

# Check what the 1,000 labels are
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt',
                'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())

print(imagenet_labels[:10])

# Visualize the input image, titled with its predicted class
plt.imshow(Sample_Image)
predicted_class = imagenet_labels[np.argmax(predicted_class)]
plt.title("Predicted Class is: " + predicted_class.title())

Evaluating MobileNet on a Batch of Flower Images - Without Fine-Tuning

The flower data consists of 5 classes.

# Specify the path of the flowers dataset : download the data
flowers_data_path = tf.keras.utils.get_file(
  'flower_photos', 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

# ์ด๋ฏธ์ง€ ์ž๋™ ๋ ˆ์ด๋ธ”๋ง : Found 3670 images belonging to 5 classes
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
            preprocessing_function=tf.keras.applications.mobilenet.preprocess_input) # ์ „์ฒ˜๋ฆฌ ๋ชจ๋“ˆ

flowers_data = image_generator.**flow_from_directory**(flowers_data_path, 
                    target_size=(224, 224), batch_size = 64, shuffle = True)

# input_batch : the images themselves
# label_batch : derived from the folder each image lives in
input_batch, label_batch = next(flowers_data) # a generator, so use next to fetch a batch

print("Image batch shape: ", input_batch.shape)    # (64, 224, 224, 3)
print("Label batch shape: ", label_batch.shape)    # (64, 5) : one-hot encoded
print("Number of label classes: ", flowers_data.num_classes) # 5
print("Class Index : ", flowers_data.class_indices) # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

# Key : label name, Value : index  =>  Key : index, Value : label name
# so we can look up which index corresponds to which label
class_names = {v: k for k, v in flowers_data.class_indices.items()}

# ์ž„์˜์˜ ๊ฝƒ ์ด๋ฏธ์ง€ 1 ๊ฐœ๋ฅผ ์„ ํƒํ•˜์—ฌ prediction ๋น„๊ต : ์„ฑ๋Šฅ์ด ์•ˆ ์ข‹๋‹ค.
prediction = Trained_Mobilenet.predict(input_batch[2:3])
decode_predictions(prediction[:, 1:])  

[[('n03930313', 'picket_fence', 5.1765366),
  ('n03944341', 'pinwheel', 4.367714),
  ('n03598930', 'jigsaw_puzzle', 3.9853535),
  ('n03447721', 'gong', 3.9513822),
  ('n09256479', 'coral_reef', 3.7957406)]]

๋ถ„๋ฅ˜๋œ ์ด๋ฏธ์ง€ ํ™•์ธ ์‹œ๊ฐํ™”

# Visualize 10 images
plt.figure(figsize=(16, 8))
for i in range(10):
    plt.subplot(1, 10, i+1)
    img = ((input_batch[i]+1)*127.5).astype(np.uint8)  # undo preprocess_input: [-1, 1] -> [0, 255]
    idx  = np.argmax(label_batch[i])
    plt.imshow(img)
    plt.title(class_names[idx])
    plt.axis('off')

์ „์ดํ•™์Šต ๋ชจ๋ธ์„ Flower ๋ถ„๋ฅ˜์— ์ ํ•ฉํ•œ ๋ชจ๋ธ๋กœ ์žฌํ›ˆ๋ จ

⇒ Fine Tuning ์„ ์œ„ํ•ด head ๊ฐ€ ์ œ๊ฑฐ๋œ model ์„ download

๋นจ๊ฐ„์ƒ‰ ๋ฐ•์Šค ์•ˆ์˜ ๊ฒƒ๋งŒ ์‚ฌ์šฉ

# ํƒ‘ ๋ ˆ์ด์–ด๋ฅผ ์ œ๊ฑฐํ•œ ๋ชจ๋ธ์˜ url
extractor_url = "<https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2>"

# ํŠน์„ฑ ํ”ผ์ฒ˜๊ฐ’
extractor_layer = hub.KerasLayer(extractor_url, input_shape=(224, 224, 3))
feature_batch = extractor_layer(input_batch)

# MobileNet ์˜ pre-trained weight ๋Š” update x 
# Top layer ์— Dense layer ์ถ”๊ฐ€
# CNN layer = ์œ ์ง€ / ์ถ”๊ฐ€ํ•œ ์™„์ „์—ฐ๊ฒฐ์ธต (Dense Layer) = ์ƒˆ๋กญ๊ฒŒ ํ•™์Šต ํ•˜๋Š” ๋žต
extractor_layer.trainable = False

# Build a model with two pieces:
#    (1)  MobileNet Feature Extractor 
#    (2)  Dense Network (classifier) added at the end 
model = tf.keras.Sequential([
  extractor_layer,
  tf.keras.layers.Dense(flowers_data.num_classes, activation='softmax')
])

# Check before training that the output shape is correct
model.input, model.output
# (input: (None, 224, 224, 3),
#  output: (None, 5))

# ๋ชจ๋ธ ์ปดํŒŒ์ผ ๋ฐ ํ›ˆ๋ จ
# ๋‹ค์ค‘๋ถ„๋ฅ˜์ด๋ฏ€๋กœ categorical_crossentropy
# 64๊ฐœ์”ฉ flower data๊ฐ€ ๋ถˆ๋Ÿฌ๋“ค์—ฌ์˜ด
model.compile(optimizer=tf.keras.optimizers.Adam(), 
              loss='categorical_crossentropy', metrics=['accuracy'])

history = model.fit(flowers_data, epochs=30)

Evaluating the MODEL Fine-Tuned Specifically for Flower Classification

# ํ™•๋ฅ  ๋ถ„ํฌ๋กœ ๋‚˜์˜ด
y_pred = model.predict(input_batch)
# ๋ถ„๋ฅ˜๋œ ๊ฒƒ์˜ ์ธ๋ฑ์Šค๋ฅผ ๊ฐ€์ง€๋„๋ก
y_pred = np.argmax(y_pred, axis=-1)
# ์ •๋‹ต ๋ฐ์ดํ„ฐ์˜ ์ธ๋ฑ์Šค ํ™•์ธ
y_true = np.argmax(label_batch, axis=-1)

# ์ •ํ™•๋„ : 100 %
f"{sum(y_pred == y_true) / len(y_true) * 100:.2f} %"

# ์˜ˆ์ธก ์‹œ๊ฐํ™”
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)

for i in range(64):
  plt.subplot(8, 8, i+1)
  img = ((input_batch[i]+1)*127.5).astype(np.uint8)
  plt.imshow(img)
  color = "green" if y_pred[i] == y_true[i] else "red"
  plt.title(class_names[y_pred[i]], color=color)
  plt.axis('off')