ถ่ายทอดการเรียนรู้และการปรับแต่ง

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

ในบทช่วยสอนนี้ คุณจะได้เรียนรู้วิธีจำแนกรูปภาพของแมวและสุนัขโดยใช้การเรียนรู้แบบโอนย้ายจากเครือข่ายที่ผ่านการฝึกอบรมมาแล้ว

โมเดลที่ได้รับการฝึกอบรมล่วงหน้าคือเครือข่ายที่บันทึกไว้ซึ่งได้รับการฝึกอบรมมาก่อนหน้านี้ในชุดข้อมูลขนาดใหญ่ โดยทั่วไปแล้วจะเป็นการจัดประเภทรูปภาพขนาดใหญ่ คุณใช้โมเดลที่ฝึกไว้ล่วงหน้าตามที่เป็นอยู่หรือใช้การเรียนรู้แบบถ่ายโอนเพื่อปรับแต่งโมเดลนี้ให้เข้ากับงานที่กำหนด

สัญชาตญาณเบื้องหลังการเรียนรู้การถ่ายโอนสำหรับการจัดประเภทรูปภาพคือ ถ้าแบบจำลองได้รับการฝึกอบรมในชุดข้อมูลที่มีขนาดใหญ่และทั่วไปเพียงพอ โมเดลนี้จะทำหน้าที่เป็นแบบจำลองทั่วไปของโลกภาพได้อย่างมีประสิทธิภาพ จากนั้น คุณสามารถใช้ประโยชน์จากแมปคุณลักษณะที่เรียนรู้เหล่านี้ได้โดยไม่ต้องเริ่มต้นใหม่โดยการฝึกโมเดลขนาดใหญ่บนชุดข้อมูลขนาดใหญ่

ในสมุดบันทึกนี้ คุณจะลองใช้สองวิธีในการปรับแต่งแบบจำลองล่วงหน้า:

  1. การแยกคุณลักษณะ: ใช้การแทนค่าที่เรียนรู้โดยเครือข่ายก่อนหน้านี้เพื่อดึงคุณลักษณะที่มีความหมายจากตัวอย่างใหม่ คุณเพียงแค่เพิ่มตัวแยกประเภทใหม่ ซึ่งจะได้รับการฝึกตั้งแต่เริ่มต้น ที่ด้านบนของโมเดลที่ได้รับการฝึกมาล่วงหน้า เพื่อให้คุณสามารถปรับใช้ฟีเจอร์แมปที่เรียนรู้ก่อนหน้านี้สำหรับชุดข้อมูลได้

    คุณไม่จำเป็นต้อง (ซ้ำ) ฝึกโมเดลทั้งหมด เครือข่าย Convolutional พื้นฐานมีคุณลักษณะที่เป็นประโยชน์โดยทั่วไปสำหรับการจัดประเภทรูปภาพอยู่แล้ว อย่างไรก็ตาม ส่วนการจำแนกขั้นสุดท้ายของแบบจำลองสำเร็จรูปนั้นจำเพาะกับงานการจำแนกประเภทดั้งเดิม และต่อมาก็เจาะจงกับชุดของชั้นเรียนที่แบบจำลองได้รับการฝึกอบรม

  2. การปรับละเอียด: ยกเลิกการตรึงเลเยอร์บนสุดบางชั้นของฐานโมเดลที่ตรึงไว้ และร่วมกันฝึกทั้งเลเยอร์ลักษณนามที่เพิ่มใหม่และเลเยอร์สุดท้ายของโมเดลฐาน ซึ่งช่วยให้เราสามารถ "ปรับแต่ง" การนำเสนอคุณลักษณะที่มีลำดับสูงกว่าในแบบจำลองพื้นฐาน เพื่อให้มีความเกี่ยวข้องกับงานเฉพาะมากขึ้น

คุณจะทำตามเวิร์กโฟลว์แมชชีนเลิร์นนิงทั่วไป

  1. ตรวจสอบและทำความเข้าใจข้อมูล
  2. สร้างไพพ์ไลน์อินพุต ในกรณีนี้โดยใช้ Keras ImageDataGenerator
  3. เขียนแบบ
    • โหลดในแบบจำลองพื้นฐานที่ฝึกไว้ล่วงหน้า (และตุ้มน้ำหนักที่ฝึกไว้ล่วงหน้า)
    • วางชั้นการจำแนกไว้ด้านบน
  4. ฝึกโมเดล
  5. ประเมินแบบจำลอง
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

การประมวลผลข้อมูลล่วงหน้า

ดาวน์โหลดข้อมูล

ในบทช่วยสอนนี้ คุณจะใช้ชุดข้อมูลที่มีภาพแมวและสุนัขหลายพันภาพ ดาวน์โหลดและแตกไฟล์ zip ที่มีรูปภาพ จากนั้นสร้าง tf.data.Dataset สำหรับการฝึกอบรมและการตรวจสอบโดยใช้ยูทิลิตี้ tf.keras.utils.image_dataset_from_directory คุณสามารถเรียนรู้เพิ่มเติมเกี่ยวกับการโหลดรูปภาพในบทช่วย สอน นี้

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            shuffle=True,
                                                            batch_size=BATCH_SIZE,
                                                            image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 1s 0us/step
68616192/68606236 [==============================] - 1s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 shuffle=True,
                                                                 batch_size=BATCH_SIZE,
                                                                 image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

แสดงภาพเก้าภาพแรกและป้ายกำกับจากชุดการฝึก:

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

png

เนื่องจากชุดข้อมูลดั้งเดิมไม่มีชุดทดสอบ คุณจะต้องสร้างชุดทดสอบขึ้นมา ในการดำเนินการดังกล่าว ให้กำหนดจำนวนชุดข้อมูลที่มีอยู่ในชุดการตรวจสอบความถูกต้องโดยใช้ tf.data.experimental.cardinality จากนั้นย้าย 20% ของข้อมูลทั้งหมดไปยังชุดทดสอบ

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6

กำหนดค่าชุดข้อมูลสำหรับประสิทธิภาพ

ใช้การดึงข้อมูลล่วงหน้าแบบบัฟเฟอร์เพื่อโหลดอิมเมจจากดิสก์โดยที่ I/O จะไม่ถูกบล็อก หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับวิธีการนี้ โปรดดูคู่มือ ประสิทธิภาพข้อมูล

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

ใช้การเสริมข้อมูล

เมื่อคุณไม่มีชุดข้อมูลรูปภาพขนาดใหญ่ เป็นการดีที่จะแนะนำความหลากหลายของตัวอย่างโดยใช้การแปลงแบบสุ่มแต่เหมือนจริงกับรูปภาพการฝึก เช่น การหมุนและการพลิกแนวนอน ซึ่งช่วยให้แบบจำลองได้มองเห็นข้อมูลการฝึกอบรมในด้านต่างๆ และลด การ overfitting คุณสามารถเรียนรู้เพิ่มเติมเกี่ยวกับการเสริมข้อมูลในบทช่วย สอน นี้

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip('horizontal'),
  tf.keras.layers.RandomRotation(0.2),
])

ลองใช้เลเยอร์เหล่านี้ซ้ำ ๆ กับรูปภาพเดียวกันและดูผลลัพธ์

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

png

ปรับขนาดค่าพิกเซล

อีกสักครู่ คุณจะดาวน์โหลด tf.keras.applications.MobileNetV2 เพื่อใช้เป็นโมเดลพื้นฐานของคุณ โมเดลนี้คาดหวังค่าพิกเซลใน [-1, 1] แต่ ณ จุดนี้ ค่าพิกเซลในภาพของคุณอยู่ใน [0, 255] หากต้องการปรับขนาดใหม่ ให้ใช้วิธีการประมวลผลล่วงหน้าที่มาพร้อมกับโมเดล

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
rescale = tf.keras.layers.Rescaling(1./127.5, offset=-1)

สร้างโมเดลพื้นฐานจากคอนเน็ตต์ที่ผ่านการฝึกอบรมมาแล้ว

คุณจะต้องสร้างโมเดลพื้นฐานจากโมเดล MobileNet V2 ที่พัฒนาโดย Google นี่เป็นการฝึกอบรมล่วงหน้าบนชุดข้อมูล ImageNet ซึ่งเป็นชุดข้อมูลขนาดใหญ่ที่ประกอบด้วยอิมเมจ 1.4M และ 1,000 คลาส ImageNet เป็นชุดข้อมูลการฝึกอบรมการวิจัยที่มีหมวดหมู่หลากหลาย เช่น jackfruit และ syringe ฐานความรู้นี้จะช่วยเราจำแนกแมวและสุนัขจากชุดข้อมูลเฉพาะของเรา

ขั้นแรก คุณต้องเลือกเลเยอร์ของ MobileNet V2 ที่คุณจะใช้สำหรับการแยกคุณลักษณะ เลเยอร์การจัดหมวดหมู่สุดท้าย (ที่ "บนสุด" เนื่องจากไดอะแกรมของโมเดลการเรียนรู้ของเครื่องส่วนใหญ่เริ่มจากล่างขึ้นบน) ไม่ค่อยมีประโยชน์ แต่คุณจะต้องปฏิบัติตามแนวทางปฏิบัติทั่วไปโดยอาศัยเลเยอร์สุดท้ายก่อนที่จะดำเนินการให้เรียบ ชั้นนี้เรียกว่า "ชั้นคอขวด" คุณสมบัติของเลเยอร์คอขวดยังคงมีลักษณะทั่วไปมากกว่าเมื่อเปรียบเทียบกับเลเยอร์สุดท้าย/บนสุด

ขั้นแรก สร้างตัวอย่างโมเดล MobileNet V2 ที่โหลดไว้ล่วงหน้าด้วยตุ้มน้ำหนักที่ฝึกบน ImageNet โดยการระบุอาร์กิวเมนต์ include_top=False คุณจะโหลดเครือข่ายที่ไม่มีเลเยอร์การจัดหมวดหมู่ที่ด้านบน ซึ่งเหมาะสำหรับการแยกคุณลักษณะ

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

ตัวแยกคุณลักษณะนี้จะแปลงรูปภาพ 160x160x3 แต่ละภาพเป็นบล็อกคุณลักษณะขนาด 5x5x1280 มาดูกันว่ามันทำอะไรกับชุดรูปภาพตัวอย่าง:

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

การแยกคุณสมบัติ

ในขั้นตอนนี้ คุณจะตรึงฐานที่บิดเบี้ยวที่สร้างขึ้นจากขั้นตอนก่อนหน้าและเพื่อใช้เป็นตัวแยกคุณลักษณะ นอกจากนี้ คุณเพิ่มตัวแยกประเภทและฝึกตัวแยกประเภทระดับบนสุด

ตรึงฐานบิด

สิ่งสำคัญคือต้องตรึงฐานที่บิดเบี้ยวก่อนที่คุณจะคอมไพล์และฝึกโมเดล การแช่แข็ง (โดยการตั้งค่า layer.trainable = False) ป้องกันไม่ให้มีการอัพเดทน้ำหนักในเลเยอร์ที่กำหนดระหว่างการฝึก MobileNet V2 มีหลายเลเยอร์ ดังนั้นการตั้งค่าแฟ trainable กที่ฝึกได้ของโมเดลทั้งหมดเป็น "เท็จ" จะทำให้เลเยอร์ทั้งหมดหยุดนิ่ง

base_model.trainable = False

หมายเหตุสำคัญเกี่ยวกับเลเยอร์ BatchNormalization

หลายรุ่นมีเลเยอร์ tf.keras.layers.BatchNormalization เลเยอร์นี้เป็นกรณีพิเศษ และควรใช้ความระมัดระวังในบริบทของการปรับแต่ง ดังที่แสดงในบทช่วยสอนนี้ในภายหลัง

เมื่อคุณตั้งค่า layer.trainable = False เลเยอร์ BatchNormalization จะทำงานในโหมดอนุมาน และจะไม่อัปเดตค่าเฉลี่ยและความแปรปรวนของสถิติ

เมื่อคุณเลิกตรึงโมเดลที่มีเลเยอร์ BatchNormalization เพื่อทำการปรับแต่งอย่างละเอียด คุณควรเก็บเลเยอร์ BatchNormalization ในโหมดอนุมานโดยผ่าน training = False เมื่อเรียกใช้โมเดลพื้นฐาน มิฉะนั้น การอัปเดตที่ใช้กับตุ้มน้ำหนักที่ไม่สามารถฝึกได้จะทำลายสิ่งที่โมเดลได้เรียนรู้

สำหรับรายละเอียดเพิ่มเติม โปรดดูที่ คู่มือการโอนย้ายการเรียนรู้

# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 160, 160, 3  0           []                               
                                )]                                                                
                                                                                                  
 Conv1 (Conv2D)                 (None, 80, 80, 32)   864         ['input_1[0][0]']                
                                                                                                  
 bn_Conv1 (BatchNormalization)  (None, 80, 80, 32)   128         ['Conv1[0][0]']                  
                                                                                                  
 Conv1_relu (ReLU)              (None, 80, 80, 32)   0           ['bn_Conv1[0][0]']               
                                                                                                  
 expanded_conv_depthwise (Depth  (None, 80, 80, 32)  288         ['Conv1_relu[0][0]']             
 wiseConv2D)                                                                                      
                                                                                                  
 expanded_conv_depthwise_BN (Ba  (None, 80, 80, 32)  128         ['expanded_conv_depthwise[0][0]']
 tchNormalization)                                                                                
                                                                                                  
 expanded_conv_depthwise_relu (  (None, 80, 80, 32)  0           ['expanded_conv_depthwise_BN[0][0
 ReLU)                                                           ]']                              
                                                                                                  
 expanded_conv_project (Conv2D)  (None, 80, 80, 16)  512         ['expanded_conv_depthwise_relu[0]
                                                                 [0]']                            
                                                                                                  
 expanded_conv_project_BN (Batc  (None, 80, 80, 16)  64          ['expanded_conv_project[0][0]']  
 hNormalization)                                                                                  
                                                                                                  
 block_1_expand (Conv2D)        (None, 80, 80, 96)   1536        ['expanded_conv_project_BN[0][0]'
                                                                 ]                                
                                                                                                  
 block_1_expand_BN (BatchNormal  (None, 80, 80, 96)  384         ['block_1_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_1_expand_relu (ReLU)     (None, 80, 80, 96)   0           ['block_1_expand_BN[0][0]']      
                                                                                                  
 block_1_pad (ZeroPadding2D)    (None, 81, 81, 96)   0           ['block_1_expand_relu[0][0]']    
                                                                                                  
 block_1_depthwise (DepthwiseCo  (None, 40, 40, 96)  864         ['block_1_pad[0][0]']            
 nv2D)                                                                                            
                                                                                                  
 block_1_depthwise_BN (BatchNor  (None, 40, 40, 96)  384         ['block_1_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_1_depthwise_relu (ReLU)  (None, 40, 40, 96)   0           ['block_1_depthwise_BN[0][0]']   
                                                                                                  
 block_1_project (Conv2D)       (None, 40, 40, 24)   2304        ['block_1_depthwise_relu[0][0]'] 
                                                                                                  
 block_1_project_BN (BatchNorma  (None, 40, 40, 24)  96          ['block_1_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_2_expand (Conv2D)        (None, 40, 40, 144)  3456        ['block_1_project_BN[0][0]']     
                                                                                                  
 block_2_expand_BN (BatchNormal  (None, 40, 40, 144)  576        ['block_2_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_2_expand_relu (ReLU)     (None, 40, 40, 144)  0           ['block_2_expand_BN[0][0]']      
                                                                                                  
 block_2_depthwise (DepthwiseCo  (None, 40, 40, 144)  1296       ['block_2_expand_relu[0][0]']    
 nv2D)                                                                                            
                                                                                                  
 block_2_depthwise_BN (BatchNor  (None, 40, 40, 144)  576        ['block_2_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_2_depthwise_relu (ReLU)  (None, 40, 40, 144)  0           ['block_2_depthwise_BN[0][0]']   
                                                                                                  
 block_2_project (Conv2D)       (None, 40, 40, 24)   3456        ['block_2_depthwise_relu[0][0]'] 
                                                                                                  
 block_2_project_BN (BatchNorma  (None, 40, 40, 24)  96          ['block_2_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_2_add (Add)              (None, 40, 40, 24)   0           ['block_1_project_BN[0][0]',     
                                                                  'block_2_project_BN[0][0]']     
                                                                                                  
 block_3_expand (Conv2D)        (None, 40, 40, 144)  3456        ['block_2_add[0][0]']            
                                                                                                  
 block_3_expand_BN (BatchNormal  (None, 40, 40, 144)  576        ['block_3_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_3_expand_relu (ReLU)     (None, 40, 40, 144)  0           ['block_3_expand_BN[0][0]']      
                                                                                                  
 block_3_pad (ZeroPadding2D)    (None, 41, 41, 144)  0           ['block_3_expand_relu[0][0]']    
                                                                                                  
 block_3_depthwise (DepthwiseCo  (None, 20, 20, 144)  1296       ['block_3_pad[0][0]']            
 nv2D)                                                                                            
                                                                                                  
 block_3_depthwise_BN (BatchNor  (None, 20, 20, 144)  576        ['block_3_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_3_depthwise_relu (ReLU)  (None, 20, 20, 144)  0           ['block_3_depthwise_BN[0][0]']   
                                                                                                  
 block_3_project (Conv2D)       (None, 20, 20, 32)   4608        ['block_3_depthwise_relu[0][0]'] 
                                                                                                  
 block_3_project_BN (BatchNorma  (None, 20, 20, 32)  128         ['block_3_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_4_expand (Conv2D)        (None, 20, 20, 192)  6144        ['block_3_project_BN[0][0]']     
                                                                                                  
 block_4_expand_BN (BatchNormal  (None, 20, 20, 192)  768        ['block_4_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_4_expand_relu (ReLU)     (None, 20, 20, 192)  0           ['block_4_expand_BN[0][0]']      
                                                                                                  
 block_4_depthwise (DepthwiseCo  (None, 20, 20, 192)  1728       ['block_4_expand_relu[0][0]']    
 nv2D)                                                                                            
                                                                                                  
 block_4_depthwise_BN (BatchNor  (None, 20, 20, 192)  768        ['block_4_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_4_depthwise_relu (ReLU)  (None, 20, 20, 192)  0           ['block_4_depthwise_BN[0][0]']   
                                                                                                  
 block_4_project (Conv2D)       (None, 20, 20, 32)   6144        ['block_4_depthwise_relu[0][0]'] 
                                                                                                  
 block_4_project_BN (BatchNorma  (None, 20, 20, 32)  128         ['block_4_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_4_add (Add)              (None, 20, 20, 32)   0           ['block_3_project_BN[0][0]',     
                                                                  'block_4_project_BN[0][0]']     
                                                                                                  
 block_5_expand (Conv2D)        (None, 20, 20, 192)  6144        ['block_4_add[0][0]']            
                                                                                                  
 block_5_expand_BN (BatchNormal  (None, 20, 20, 192)  768        ['block_5_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_5_expand_relu (ReLU)     (None, 20, 20, 192)  0           ['block_5_expand_BN[0][0]']      
                                                                                                  
 block_5_depthwise (DepthwiseCo  (None, 20, 20, 192)  1728       ['block_5_expand_relu[0][0]']    
 nv2D)                                                                                            
                                                                                                  
 block_5_depthwise_BN (BatchNor  (None, 20, 20, 192)  768        ['block_5_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_5_depthwise_relu (ReLU)  (None, 20, 20, 192)  0           ['block_5_depthwise_BN[0][0]']   
                                                                                                  
 block_5_project (Conv2D)       (None, 20, 20, 32)   6144        ['block_5_depthwise_relu[0][0]'] 
                                                                                                  
 block_5_project_BN (BatchNorma  (None, 20, 20, 32)  128         ['block_5_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_5_add (Add)              (None, 20, 20, 32)   0           ['block_4_add[0][0]',            
                                                                  'block_5_project_BN[0][0]']     
                                                                                                  
 block_6_expand (Conv2D)        (None, 20, 20, 192)  6144        ['block_5_add[0][0]']            
                                                                                                  
 block_6_expand_BN (BatchNormal  (None, 20, 20, 192)  768        ['block_6_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_6_expand_relu (ReLU)     (None, 20, 20, 192)  0           ['block_6_expand_BN[0][0]']      
                                                                                                  
 block_6_pad (ZeroPadding2D)    (None, 21, 21, 192)  0           ['block_6_expand_relu[0][0]']    
                                                                                                  
 block_6_depthwise (DepthwiseCo  (None, 10, 10, 192)  1728       ['block_6_pad[0][0]']            
 nv2D)                                                                                            
                                                                                                  
 block_6_depthwise_BN (BatchNor  (None, 10, 10, 192)  768        ['block_6_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_6_depthwise_relu (ReLU)  (None, 10, 10, 192)  0           ['block_6_depthwise_BN[0][0]']   
                                                                                                  
 block_6_project (Conv2D)       (None, 10, 10, 64)   12288       ['block_6_depthwise_relu[0][0]'] 
                                                                                                  
 block_6_project_BN (BatchNorma  (None, 10, 10, 64)  256         ['block_6_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_7_expand (Conv2D)        (None, 10, 10, 384)  24576       ['block_6_project_BN[0][0]']     
                                                                                                  
 block_7_expand_BN (BatchNormal  (None, 10, 10, 384)  1536       ['block_7_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_7_expand_relu (ReLU)     (None, 10, 10, 384)  0           ['block_7_expand_BN[0][0]']      
                                                                                                  
 block_7_depthwise (DepthwiseCo  (None, 10, 10, 384)  3456       ['block_7_expand_relu[0][0]']    
 nv2D)                                                                                            
                                                                                                  
 block_7_depthwise_BN (BatchNor  (None, 10, 10, 384)  1536       ['block_7_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_7_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           ['block_7_depthwise_BN[0][0]']   
                                                                                                  
 block_7_project (Conv2D)       (None, 10, 10, 64)   24576       ['block_7_depthwise_relu[0][0]'] 
                                                                                                  
 block_7_project_BN (BatchNorma  (None, 10, 10, 64)  256         ['block_7_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_7_add (Add)              (None, 10, 10, 64)   0           ['block_6_project_BN[0][0]',     
                                                                  'block_7_project_BN[0][0]']     
                                                                                                  
 block_8_expand (Conv2D)        (None, 10, 10, 384)  24576       ['block_7_add[0][0]']            
                                                                                                  
 block_8_expand_BN (BatchNormal  (None, 10, 10, 384)  1536       ['block_8_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_8_expand_relu (ReLU)     (None, 10, 10, 384)  0           ['block_8_expand_BN[0][0]']      
                                                                                                  
 block_8_depthwise (DepthwiseCo  (None, 10, 10, 384)  3456       ['block_8_expand_relu[0][0]']    
 nv2D)                                                                                            
                                                                                                  
 block_8_depthwise_BN (BatchNor  (None, 10, 10, 384)  1536       ['block_8_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_8_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           ['block_8_depthwise_BN[0][0]']   
                                                                                                  
 block_8_project (Conv2D)       (None, 10, 10, 64)   24576       ['block_8_depthwise_relu[0][0]'] 
                                                                                                  
 block_8_project_BN (BatchNorma  (None, 10, 10, 64)  256         ['block_8_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_8_add (Add)              (None, 10, 10, 64)   0           ['block_7_add[0][0]',            
                                                                  'block_8_project_BN[0][0]']     
                                                                                                  
 block_9_expand (Conv2D)        (None, 10, 10, 384)  24576       ['block_8_add[0][0]']            
                                                                                                  
 block_9_expand_BN (BatchNormal  (None, 10, 10, 384)  1536       ['block_9_expand[0][0]']         
 ization)                                                                                         
                                                                                                  
 block_9_expand_relu (ReLU)     (None, 10, 10, 384)  0           ['block_9_expand_BN[0][0]']      
                                                                                                  
 block_9_depthwise (DepthwiseCo  (None, 10, 10, 384)  3456       ['block_9_expand_relu[0][0]']    
 nv2D)                                                                                            
                                                                                                  
 block_9_depthwise_BN (BatchNor  (None, 10, 10, 384)  1536       ['block_9_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_9_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           ['block_9_depthwise_BN[0][0]']   
                                                                                                  
 block_9_project (Conv2D)       (None, 10, 10, 64)   24576       ['block_9_depthwise_relu[0][0]'] 
                                                                                                  
 block_9_project_BN (BatchNorma  (None, 10, 10, 64)  256         ['block_9_project[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_9_add (Add)              (None, 10, 10, 64)   0           ['block_8_add[0][0]',            
                                                                  'block_9_project_BN[0][0]']     
                                                                                                  
 block_10_expand (Conv2D)       (None, 10, 10, 384)  24576       ['block_9_add[0][0]']            
                                                                                                  
 block_10_expand_BN (BatchNorma  (None, 10, 10, 384)  1536       ['block_10_expand[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_10_expand_relu (ReLU)    (None, 10, 10, 384)  0           ['block_10_expand_BN[0][0]']     
                                                                                                  
 block_10_depthwise (DepthwiseC  (None, 10, 10, 384)  3456       ['block_10_expand_relu[0][0]']   
 onv2D)                                                                                           
                                                                                                  
 block_10_depthwise_BN (BatchNo  (None, 10, 10, 384)  1536       ['block_10_depthwise[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0          ['block_10_depthwise_BN[0][0]']  
                                                                                                  
 block_10_project (Conv2D)      (None, 10, 10, 96)   36864       ['block_10_depthwise_relu[0][0]']
                                                                                                  
 block_10_project_BN (BatchNorm  (None, 10, 10, 96)  384         ['block_10_project[0][0]']       
 alization)                                                                                       
                                                                                                  
 block_11_expand (Conv2D)       (None, 10, 10, 576)  55296       ['block_10_project_BN[0][0]']    
                                                                                                  
 block_11_expand_BN (BatchNorma  (None, 10, 10, 576)  2304       ['block_11_expand[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_11_expand_relu (ReLU)    (None, 10, 10, 576)  0           ['block_11_expand_BN[0][0]']     
                                                                                                  
 block_11_depthwise (DepthwiseC  (None, 10, 10, 576)  5184       ['block_11_expand_relu[0][0]']   
 onv2D)                                                                                           
                                                                                                  
 block_11_depthwise_BN (BatchNo  (None, 10, 10, 576)  2304       ['block_11_depthwise[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0          ['block_11_depthwise_BN[0][0]']  
                                                                                                  
 block_11_project (Conv2D)      (None, 10, 10, 96)   55296       ['block_11_depthwise_relu[0][0]']
                                                                                                  
 block_11_project_BN (BatchNorm  (None, 10, 10, 96)  384         ['block_11_project[0][0]']       
 alization)                                                                                       
                                                                                                  
 block_11_add (Add)             (None, 10, 10, 96)   0           ['block_10_project_BN[0][0]',    
                                                                  'block_11_project_BN[0][0]']    
                                                                                                  
 block_12_expand (Conv2D)       (None, 10, 10, 576)  55296       ['block_11_add[0][0]']           
                                                                                                  
 block_12_expand_BN (BatchNorma  (None, 10, 10, 576)  2304       ['block_12_expand[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_12_expand_relu (ReLU)    (None, 10, 10, 576)  0           ['block_12_expand_BN[0][0]']     
                                                                                                  
 block_12_depthwise (DepthwiseC  (None, 10, 10, 576)  5184       ['block_12_expand_relu[0][0]']   
 onv2D)                                                                                           
                                                                                                  
 block_12_depthwise_BN (BatchNo  (None, 10, 10, 576)  2304       ['block_12_depthwise[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0          ['block_12_depthwise_BN[0][0]']  
                                                                                                  
 block_12_project (Conv2D)      (None, 10, 10, 96)   55296       ['block_12_depthwise_relu[0][0]']
                                                                                                  
 block_12_project_BN (BatchNorm  (None, 10, 10, 96)  384         ['block_12_project[0][0]']       
 alization)                                                                                       
                                                                                                  
 block_12_add (Add)             (None, 10, 10, 96)   0           ['block_11_add[0][0]',           
                                                                  'block_12_project_BN[0][0]']    
                                                                                                  
 block_13_expand (Conv2D)       (None, 10, 10, 576)  55296       ['block_12_add[0][0]']           
                                                                                                  
 block_13_expand_BN (BatchNorma  (None, 10, 10, 576)  2304       ['block_13_expand[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_13_expand_relu (ReLU)    (None, 10, 10, 576)  0           ['block_13_expand_BN[0][0]']     
                                                                                                  
 block_13_pad (ZeroPadding2D)   (None, 11, 11, 576)  0           ['block_13_expand_relu[0][0]']   
                                                                                                  
 block_13_depthwise (DepthwiseC  (None, 5, 5, 576)   5184        ['block_13_pad[0][0]']           
 onv2D)                                                                                           
                                                                                                  
 block_13_depthwise_BN (BatchNo  (None, 5, 5, 576)   2304        ['block_13_depthwise[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)   0           ['block_13_depthwise_BN[0][0]']  
                                                                                                  
 block_13_project (Conv2D)      (None, 5, 5, 160)    92160       ['block_13_depthwise_relu[0][0]']
                                                                                                  
 block_13_project_BN (BatchNorm  (None, 5, 5, 160)   640         ['block_13_project[0][0]']       
 alization)                                                                                       
                                                                                                  
 block_14_expand (Conv2D)       (None, 5, 5, 960)    153600      ['block_13_project_BN[0][0]']    
                                                                                                  
 block_14_expand_BN (BatchNorma  (None, 5, 5, 960)   3840        ['block_14_expand[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_14_expand_relu (ReLU)    (None, 5, 5, 960)    0           ['block_14_expand_BN[0][0]']     
                                                                                                  
 block_14_depthwise (DepthwiseC  (None, 5, 5, 960)   8640        ['block_14_expand_relu[0][0]']   
 onv2D)                                                                                           
                                                                                                  
 block_14_depthwise_BN (BatchNo  (None, 5, 5, 960)   3840        ['block_14_depthwise[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)   0           ['block_14_depthwise_BN[0][0]']  
                                                                                                  
 block_14_project (Conv2D)      (None, 5, 5, 160)    153600      ['block_14_depthwise_relu[0][0]']
                                                                                                  
 block_14_project_BN (BatchNorm  (None, 5, 5, 160)   640         ['block_14_project[0][0]']       
 alization)                                                                                       
                                                                                                  
 block_14_add (Add)             (None, 5, 5, 160)    0           ['block_13_project_BN[0][0]',    
                                                                  'block_14_project_BN[0][0]']    
                                                                                                  
 block_15_expand (Conv2D)       (None, 5, 5, 960)    153600      ['block_14_add[0][0]']           
                                                                                                  
 block_15_expand_BN (BatchNorma  (None, 5, 5, 960)   3840        ['block_15_expand[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_15_expand_relu (ReLU)    (None, 5, 5, 960)    0           ['block_15_expand_BN[0][0]']     
                                                                                                  
 block_15_depthwise (DepthwiseC  (None, 5, 5, 960)   8640        ['block_15_expand_relu[0][0]']   
 onv2D)                                                                                           
                                                                                                  
 block_15_depthwise_BN (BatchNo  (None, 5, 5, 960)   3840        ['block_15_depthwise[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)   0           ['block_15_depthwise_BN[0][0]']  
                                                                                                  
 block_15_project (Conv2D)      (None, 5, 5, 160)    153600      ['block_15_depthwise_relu[0][0]']
                                                                                                  
 block_15_project_BN (BatchNorm  (None, 5, 5, 160)   640         ['block_15_project[0][0]']       
 alization)                                                                                       
                                                                                                  
 block_15_add (Add)             (None, 5, 5, 160)    0           ['block_14_add[0][0]',           
                                                                  'block_15_project_BN[0][0]']    
                                                                                                  
 block_16_expand (Conv2D)       (None, 5, 5, 960)    153600      ['block_15_add[0][0]']           
                                                                                                  
 block_16_expand_BN (BatchNorma  (None, 5, 5, 960)   3840        ['block_16_expand[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_16_expand_relu (ReLU)    (None, 5, 5, 960)    0           ['block_16_expand_BN[0][0]']     
                                                                                                  
 block_16_depthwise (DepthwiseC  (None, 5, 5, 960)   8640        ['block_16_expand_relu[0][0]']   
 onv2D)                                                                                           
                                                                                                  
 block_16_depthwise_BN (BatchNo  (None, 5, 5, 960)   3840        ['block_16_depthwise[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)   0           ['block_16_depthwise_BN[0][0]']  
                                                                                                  
 block_16_project (Conv2D)      (None, 5, 5, 320)    307200      ['block_16_depthwise_relu[0][0]']
                                                                                                  
 block_16_project_BN (BatchNorm  (None, 5, 5, 320)   1280        ['block_16_project[0][0]']       
 alization)                                                                                       
                                                                                                  
 Conv_1 (Conv2D)                (None, 5, 5, 1280)   409600      ['block_16_project_BN[0][0]']    
                                                                                                  
 Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)  5120        ['Conv_1[0][0]']                 
                                                                                                  
 out_relu (ReLU)                (None, 5, 5, 1280)   0           ['Conv_1_bn[0][0]']              
                                                                                                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

เพิ่มหัวการจัดประเภท

ในการสร้างการคาดคะเนจากกลุ่มคุณลักษณะ ให้หาค่าเฉลี่ยของตำแหน่งเชิงพื้นที่ 5x5 โดยใช้เลเยอร์ tf.keras.layers.GlobalAveragePooling2D เพื่อแปลงคุณสมบัติเป็นเวกเตอร์องค์ประกอบ 1280 เดียวต่อภาพ

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)
ตัวยึดตำแหน่ง23

ใช้เลเยอร์ tf.keras.layers.Dense เพื่อแปลงคุณสมบัติเหล่านี้เป็นการคาดคะเนภาพเดียวต่อภาพ คุณไม่จำเป็นต้องมีฟังก์ชันการเปิดใช้งานที่นี่ เนื่องจากการคาดการณ์นี้จะถือเป็น logit หรือค่าการทำนายแบบดิบ ตัวเลขบวกทำนายชั้น 1 ตัวเลขติดลบทำนายชั้น 0

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)

สร้างโมเดลโดยเชื่อมโยงการเพิ่มข้อมูล การปรับขนาด base_model และเลเยอร์ตัวแยกคุณลักษณะโดยใช้ Keras Functional API ตามที่กล่าวไว้ก่อนหน้านี้ ให้ใช้ training=False เนื่องจากโมเดลของเรามีเลเยอร์ BatchNormalization

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

รวบรวมโมเดล

รวบรวมโมเดลก่อนฝึก เนื่องจากมีสองคลาส ให้ใช้ tf.keras.losses.BinaryCrossentropy loss ด้วย from_logits=True เนื่องจากโมเดลจัดเตรียมเอาต์พุตเชิงเส้น

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 160, 160, 3)]     0         
                                                                 
 sequential (Sequential)     (None, 160, 160, 3)       0         
                                                                 
 tf.math.truediv (TFOpLambda  (None, 160, 160, 3)      0         
 )                                                               
                                                                 
 tf.math.subtract (TFOpLambd  (None, 160, 160, 3)      0         
 a)                                                              
                                                                 
 mobilenetv2_1.00_160 (Funct  (None, 5, 5, 1280)       2257984   
 ional)                                                          
                                                                 
 global_average_pooling2d (G  (None, 1280)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 1)                 1281      
                                                                 
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

พารามิเตอร์ 2.5 ล้านรายการใน MobileNet ถูกระงับ แต่มีพารามิเตอร์ที่ฝึกได้ 1.2 พันรายการในเลเยอร์หนาแน่น สิ่งเหล่านี้ถูกแบ่งระหว่างวัตถุสอง tf.Variable คือน้ำหนักและอคติ

len(model.trainable_variables)
2

ฝึกโมเดล

หลังจากการฝึกอบรมเป็นเวลา 10 ยุค คุณควรเห็นความแม่นยำ ~94% ในชุดการตรวจสอบความถูกต้อง

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 2s 16ms/step - loss: 0.7428 - accuracy: 0.5186
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.74
initial accuracy: 0.52
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 4s 23ms/step - loss: 0.6804 - accuracy: 0.5680 - val_loss: 0.4981 - val_accuracy: 0.7054
Epoch 2/10
63/63 [==============================] - 1s 22ms/step - loss: 0.5044 - accuracy: 0.7170 - val_loss: 0.3598 - val_accuracy: 0.8144
Epoch 3/10
63/63 [==============================] - 1s 21ms/step - loss: 0.4109 - accuracy: 0.7845 - val_loss: 0.2810 - val_accuracy: 0.8861
Epoch 4/10
63/63 [==============================] - 1s 21ms/step - loss: 0.3285 - accuracy: 0.8445 - val_loss: 0.2256 - val_accuracy: 0.9208
Epoch 5/10
63/63 [==============================] - 1s 21ms/step - loss: 0.3108 - accuracy: 0.8555 - val_loss: 0.1986 - val_accuracy: 0.9307
Epoch 6/10
63/63 [==============================] - 1s 21ms/step - loss: 0.2659 - accuracy: 0.8855 - val_loss: 0.1703 - val_accuracy: 0.9418
Epoch 7/10
63/63 [==============================] - 1s 21ms/step - loss: 0.2459 - accuracy: 0.8935 - val_loss: 0.1495 - val_accuracy: 0.9517
Epoch 8/10
63/63 [==============================] - 1s 21ms/step - loss: 0.2315 - accuracy: 0.8950 - val_loss: 0.1454 - val_accuracy: 0.9542
Epoch 9/10
63/63 [==============================] - 1s 21ms/step - loss: 0.2204 - accuracy: 0.9030 - val_loss: 0.1326 - val_accuracy: 0.9592
Epoch 10/10
63/63 [==============================] - 1s 21ms/step - loss: 0.2180 - accuracy: 0.9115 - val_loss: 0.1215 - val_accuracy: 0.9604

เส้นโค้งการเรียนรู้

มาดูเส้นโค้งการเรียนรู้ของการฝึกและความแม่นยำ/การสูญเสียการตรวจสอบเมื่อใช้โมเดลพื้นฐานของ MobileNetV2 เป็นตัวแยกคุณลักษณะแบบตายตัว

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

ในระดับที่น้อยกว่า เป็นเพราะตัววัดการฝึกอบรมรายงานค่าเฉลี่ยสำหรับยุคนั้น ในขณะที่ตัววัดการตรวจสอบความถูกต้องจะได้รับการประเมินหลังจากยุคนั้น ดังนั้น ตัววัดการตรวจสอบความถูกต้องจะเห็นแบบจำลองที่ได้รับการฝึกนานกว่าเล็กน้อย

ปรับจูน

ในการทดสอบการแยกคุณลักษณะ คุณฝึกเพียงไม่กี่เลเยอร์บนโมเดลพื้นฐานของ MobileNetV2 น้ำหนักของเครือข่ายที่ฝึกไว้ล่วงหน้า ไม่ได้ รับการอัพเดตระหว่างการฝึก

วิธีหนึ่งในการเพิ่มประสิทธิภาพให้ดียิ่งขึ้นไปอีกคือการฝึก (หรือ "ปรับแต่ง") ตุ้มน้ำหนักของเลเยอร์บนสุดของโมเดลที่ฝึกไว้ล่วงหน้าควบคู่ไปกับการฝึกของตัวแยกประเภทที่คุณเพิ่ม กระบวนการฝึกอบรมจะบังคับให้ปรับน้ำหนักจากแผนที่คุณลักษณะทั่วไปเป็นคุณลักษณะที่เกี่ยวข้องกับชุดข้อมูลโดยเฉพาะ

นอกจากนี้ คุณควรพยายามปรับแต่งเลเยอร์บนสุดจำนวนเล็กน้อยแทนที่จะปรับรุ่น MobileNet ทั้งหมด ในเครือข่ายแบบ Convolutional ส่วนใหญ่ ยิ่งชั้นสูงเท่าไหร่ก็ยิ่งมีความพิเศษมากขึ้นเท่านั้น เลเยอร์สองสามชั้นแรกจะเรียนรู้คุณลักษณะทั่วไปที่เรียบง่ายและทั่วถึง ซึ่งสามารถสรุปได้ทั่วไปกับรูปภาพเกือบทุกประเภท เมื่อคุณสูงขึ้น ฟีเจอร์ต่างๆ จะมีความเฉพาะเจาะจงมากขึ้นสำหรับชุดข้อมูลที่โมเดลได้รับการฝึกอบรม เป้าหมายของการปรับแต่งแบบละเอียดคือการปรับคุณลักษณะเฉพาะเหล่านี้ให้ทำงานกับชุดข้อมูลใหม่ แทนที่จะเขียนทับการเรียนรู้ทั่วไป

ยกเลิกการตรึงชั้นบนสุดของโมเดล

สิ่งที่คุณต้องทำคือยกเลิกการตรึง base_model และตั้งค่าชั้นล่างให้ไม่สามารถฝึกได้ จากนั้น คุณควรคอมไพล์โมเดลใหม่ (จำเป็นสำหรับการเปลี่ยนแปลงเหล่านี้เพื่อให้มีผล) และดำเนินการฝึกต่อ

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable = False
Number of layers in the base model:  154

รวบรวมโมเดล

ขณะที่คุณกำลังฝึกโมเดลที่ใหญ่กว่ามากและต้องการอ่านตุ้มน้ำหนักที่ฝึกไว้ล่วงหน้า สิ่งสำคัญคือต้องใช้อัตราการเรียนรู้ที่ต่ำกว่าในขั้นตอนนี้ มิฉะนั้น โมเดลของคุณอาจเกินพอดีอย่างรวดเร็ว

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 160, 160, 3)]     0         
                                                                 
 sequential (Sequential)     (None, 160, 160, 3)       0         
                                                                 
 tf.math.truediv (TFOpLambda  (None, 160, 160, 3)      0         
 )                                                               
                                                                 
 tf.math.subtract (TFOpLambd  (None, 160, 160, 3)      0         
 a)                                                              
                                                                 
 mobilenetv2_1.00_160 (Funct  (None, 5, 5, 1280)       2257984   
 ional)                                                          
                                                                 
 global_average_pooling2d (G  (None, 1280)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 1)                 1281      
                                                                 
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56

ฝึกโมเดลต่อไป

หากคุณฝึกฝนการบรรจบกันก่อนหน้านี้ ขั้นตอนนี้จะช่วยปรับปรุงความแม่นยำของคุณสองสามเปอร์เซ็นต์

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 7s 40ms/step - loss: 0.1545 - accuracy: 0.9335 - val_loss: 0.0531 - val_accuracy: 0.9864
Epoch 11/20
63/63 [==============================] - 2s 28ms/step - loss: 0.1161 - accuracy: 0.9540 - val_loss: 0.0500 - val_accuracy: 0.9814
Epoch 12/20
63/63 [==============================] - 2s 28ms/step - loss: 0.1125 - accuracy: 0.9525 - val_loss: 0.0379 - val_accuracy: 0.9876
Epoch 13/20
63/63 [==============================] - 2s 28ms/step - loss: 0.0891 - accuracy: 0.9625 - val_loss: 0.0472 - val_accuracy: 0.9889
Epoch 14/20
63/63 [==============================] - 2s 28ms/step - loss: 0.0844 - accuracy: 0.9680 - val_loss: 0.0478 - val_accuracy: 0.9889
Epoch 15/20
63/63 [==============================] - 2s 28ms/step - loss: 0.0857 - accuracy: 0.9645 - val_loss: 0.0354 - val_accuracy: 0.9839
Epoch 16/20
63/63 [==============================] - 2s 28ms/step - loss: 0.0785 - accuracy: 0.9690 - val_loss: 0.0449 - val_accuracy: 0.9864
Epoch 17/20
63/63 [==============================] - 2s 28ms/step - loss: 0.0669 - accuracy: 0.9740 - val_loss: 0.0375 - val_accuracy: 0.9839
Epoch 18/20
63/63 [==============================] - 2s 28ms/step - loss: 0.0701 - accuracy: 0.9695 - val_loss: 0.0324 - val_accuracy: 0.9864
Epoch 19/20
63/63 [==============================] - 2s 28ms/step - loss: 0.0636 - accuracy: 0.9760 - val_loss: 0.0465 - val_accuracy: 0.9790
Epoch 20/20
63/63 [==============================] - 2s 29ms/step - loss: 0.0585 - accuracy: 0.9765 - val_loss: 0.0392 - val_accuracy: 0.9851

มาดูเส้นโค้งการเรียนรู้ของการฝึกและความแม่นยำ/การสูญเสียการตรวจสอบเมื่อทำการปรับแต่งสองสามเลเยอร์สุดท้ายของโมเดลพื้นฐาน MobileNetV2 และฝึกตัวแยกประเภทที่อยู่ด้านบน การสูญเสียการตรวจสอบจะสูงกว่าการสูญเสียการฝึกอบรมมาก ดังนั้นคุณอาจได้รับการฝึกฝนมากเกินไป

คุณอาจได้รับการตั้งค่ามากเกินไปเนื่องจากชุดการฝึกใหม่มีขนาดค่อนข้างเล็กและคล้ายกับชุดข้อมูล MobileNetV2 ดั้งเดิม

หลังจากปรับแต่งโมเดลอย่างละเอียดแล้ว เกือบถึงความแม่นยำถึง 98% ในชุดการตรวจสอบ

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

การประเมินและการทำนาย

ขั้นสุดท้ายคุณสามารถตรวจสอบประสิทธิภาพของแบบจำลองกับข้อมูลใหม่ได้โดยใช้ชุดทดสอบ

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 0s 13ms/step - loss: 0.0281 - accuracy: 0.9948
Test accuracy : 0.9947916865348816
ตัวยึดตำแหน่ง52

และตอนนี้คุณก็พร้อมแล้วที่จะใช้โมเดลนี้ในการทำนายว่าสัตว์เลี้ยงของคุณคือแมวหรือสุนัข

# Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
Predictions:
 [0 1 1 1 1 0 1 1 1 0 1 1 0 1 1 1 0 0 0 1 0 1 0 0 1 1 1 0 0 0 1 0]
Labels:
 [0 1 1 1 1 0 1 1 1 0 1 1 0 1 1 1 0 0 0 1 0 1 0 0 1 1 1 0 0 0 1 0]

png

สรุป

  • การใช้โมเดลที่ฝึกไว้ล่วงหน้าสำหรับการดึงข้อมูลคุณลักษณะ : เมื่อทำงานกับชุดข้อมูลขนาดเล็ก เป็นเรื่องปกติที่จะใช้ประโยชน์จากคุณลักษณะที่เรียนรู้โดยโมเดลที่ได้รับการฝึกฝนบนชุดข้อมูลขนาดใหญ่กว่าในโดเมนเดียวกัน ซึ่งทำได้โดยการสร้างอินสแตนซ์ของโมเดลที่ฝึกอบรมไว้ล่วงหน้าและเพิ่มตัวแยกประเภทที่เชื่อมต่ออย่างสมบูรณ์ที่ด้านบน โมเดลที่ฝึกไว้ล่วงหน้านั้น "หยุดนิ่ง" และเฉพาะตุ้มน้ำหนักของลักษณนามเท่านั้นที่จะได้รับการอัปเดตระหว่างการฝึก ในกรณีนี้ Convolutional Base จะแยกคุณลักษณะทั้งหมดที่เกี่ยวข้องกับแต่ละภาพ และคุณเพิ่งฝึกตัวแยกประเภทที่กำหนดคลาสของรูปภาพตามชุดของคุณลักษณะที่แยกออกมานั้น

  • การปรับแต่งโมเดลที่ฝึกไว้ล่วงหน้าอย่างละเอียด : เพื่อปรับปรุงประสิทธิภาพเพิ่มเติม เราอาจต้องการปรับเปลี่ยนเลเยอร์ระดับบนสุดของโมเดลที่ได้รับการฝึกอบรมล่วงหน้าไปยังชุดข้อมูลใหม่ผ่านการปรับแต่งแบบละเอียด ในกรณีนี้ คุณปรับตุ้มน้ำหนักของคุณเพื่อให้โมเดลเรียนรู้คุณลักษณะระดับสูงเฉพาะสำหรับชุดข้อมูล เทคนิคนี้มักจะแนะนำเมื่อชุดข้อมูลการฝึกมีขนาดใหญ่และคล้ายกันมากกับชุดข้อมูลดั้งเดิมที่มีการฝึกแบบจำลองล่วงหน้า

หากต้องการเรียนรู้เพิ่มเติม โปรดไปที่ คู่มือการโอนย้ายการเรียนรู้

# MIT License
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.