
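Transfer learning notes

1. Feature extraction
- Extract features with only the convolutional part of a pretrained CNN => because convolution is a sliding-window operation, the input size does not have to be fixed
- input_shape can therefore be left unspecified, but it is sometimes fixed so that the model can be built in TensorFlow/Keras
- Feature extraction alone generally performs worse than training the whole network. When data is scarce, however, the pretrained features still carry what was learned, so they can help to some degree (fine tuning is needed afterwards)
2. Fine tuning
- A large learning rate causes catastrophic forgetting (the pretrained weights get overwritten), so retrain with a small learning rate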
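The two steps above can be sketched in Keras roughly as follows. This is a minimal sketch, not the author's code: the helper names `build_transfer_model` and `start_fine_tuning`, the classification head (global average pooling plus a softmax `Dense` layer), and the learning rates 1e-3 and 1e-5 are illustrative assumptions. Phase 1 freezes the convolutional base and trains only the new head (feature extraction); phase 2 unfreezes the base and recompiles with a much smaller learning rate (fine tuning).

```python
import tensorflow as tf


def build_transfer_model(num_classes, input_shape=(224, 224, 3), weights="imagenet"):
    # Phase 1 (feature extraction): pretrained convolutional base without its top
    base = tf.keras.applications.MobileNetV2(
        input_shape=input_shape, include_top=False, weights=weights)
    base.trainable = False  # freeze every layer of the base

    inputs = tf.keras.Input(shape=input_shape)
    # training=False keeps BatchNormalization in inference mode, even after unfreezing
    x = base(inputs, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                  loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model, base


def start_fine_tuning(model, base, learning_rate=1e-5):
    # Phase 2 (fine tuning): unfreeze the base and recompile with a small learning
    # rate so the pretrained weights are only nudged, not overwritten
    base.trainable = True
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
                  loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model
```

Typical use would be a `model.fit(...)` on the new data after `build_transfer_model`, then `start_fine_tuning` followed by another short `fit`. Recompiling after changing `trainable` is needed so Keras takes the change into account, and passing `training=False` when calling the base follows the Keras transfer-learning guide, so that unfreezing does not let the BatchNormalization statistics drift and undo what the base had learned.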

 

Implementing U-net

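import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# MobileNetV2 without its classification head (include_top=False): only the
# convolutional feature extractor is kept, and input_shape is left unspecified
base_model = tf.keras.applications.MobileNetV2(include_top=False)
base_model.summary()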
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_12 (InputLayer)           [(None, None, None,  0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, None, None, 3 864         input_12[0][0]                   
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, None, None, 3 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, None, None, 3 0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, None, None, 3 288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, None, None, 3 128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, None, None, 3 0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, None, None, 1 512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, None, None, 1 64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, None, None, 9 1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, None, None, 9 384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, None, None, 9 0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, None, None, 9 0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, None, None, 9 864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, None, None, 9 384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, None, None, 9 0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, None, None, 2 2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, None, None, 2 96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, None, None, 1 3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, None, None, 1 576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, None, None, 1 0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, None, None, 1 1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, None, None, 1 576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, None, None, 1 0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, None, None, 2 3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, None, None, 2 96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, None, None, 2 0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, None, None, 1 3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, None, None, 1 576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, None, None, 1 0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, None, None, 1 0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, None, None, 1 1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, None, None, 1 576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, None, None, 1 0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, None, None, 3 4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, None, None, 3 128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, None, None, 1 6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, None, None, 1 768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, None, None, 1 0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, None, None, 1 1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, None, None, 1 768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, None, None, 1 0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, None, None, 3 6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, None, None, 3 128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, None, None, 3 0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, None, None, 1 6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, None, None, 1 768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, None, None, 1 0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, None, None, 1 1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, None, None, 1 768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, None, None, 1 0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, None, None, 3 6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, None, None, 3 128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, None, None, 3 0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, None, None, 1 6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, None, None, 1 768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, None, None, 1 0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, None, None, 1 0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, None, None, 1 1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, None, None, 1 768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, None, None, 1 0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, None, None, 6 12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, None, None, 6 256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, None, None, 3 24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, None, None, 3 1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, None, None, 3 0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, None, None, 3 3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, None, None, 3 1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, None, None, 3 0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, None, None, 6 24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, None, None, 6 256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, None, None, 6 0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, None, None, 3 24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, None, None, 3 1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, None, None, 3 0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, None, None, 3 3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, None, None, 3 1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, None, None, 3 0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, None, None, 6 24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, None, None, 6 256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, None, None, 6 0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, None, None, 3 24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, None, None, 3 1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, None, None, 3 0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, None, None, 3 3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, None, None, 3 1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, None, None, 3 0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, None, None, 6 24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, None, None, 6 256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, None, None, 6 0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, None, None, 3 24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, None, None, 3 1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, None, None, 3 0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, None, None, 3 3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, None, None, 3 1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, None, None, 3 0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, None, None, 9 36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, None, None, 9 384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, None, None, 5 55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, None, None, 5 2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, None, None, 5 0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, None, None, 5 5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, None, None, 5 2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, None, None, 5 0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, None, None, 9 55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, None, None, 9 384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, None, None, 9 0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, None, None, 5 55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, None, None, 5 2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, None, None, 5 0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, None, None, 5 5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, None, None, 5 2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, None, None, 5 0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, None, None, 9 55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, None, None, 9 384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, None, None, 9 0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, None, None, 5 55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, None, None, 5 2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, None, None, 5 0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, None, None, 5 0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, None, None, 5 5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, None, None, 5 2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, None, None, 5 0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, None, None, 1 92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, None, None, 1 640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, None, None, 9 153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, None, None, 9 3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, None, None, 9 0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, None, None, 9 8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, None, None, 9 3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, None, None, 9 0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, None, None, 1 153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, None, None, 1 640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, None, None, 1 0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, None, None, 9 153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, None, None, 9 3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, None, None, 9 0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, None, None, 9 8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, None, None, 9 3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, None, None, 9 0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, None, None, 1 153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, None, None, 1 640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, None, None, 1 0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, None, None, 9 153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, None, None, 9 3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, None, None, 9 0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, None, None, 9 8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, None, None, 9 3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, None, None, 9 0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, None, None, 3 307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, None, None, 3 1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, None, None, 1 409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, None, None, 1 5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, None, None, 1 0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 2,223,872
Non-trainable params: 34,112
__________________________________________________________________________________________________
im = tf.keras.preprocessing.image.load_img('ade.jpg')

np.array(im).shape
# (960, 1280, 3)
plt.imshow(im)

# base_model was built for 224x224x3 inputs, but because no input_shape was fixed,
# an image of any size (here 960x1280) can be passed in as input
base_model(np.array(im)[np.newaxis])

<tf.Tensor: shape=(1, 30, 40, 1280), dtype=float32, numpy=
array([[[[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         ...,
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.02680304],
         [0.        , 0.        , 0.        , ..., 0.        ,
          2.0374901 , 1.3772459 ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          1.1740499 , 0.43281242],
         ...,
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.        , 0.7047172 , 0.        , ..., 0.        ,
          1.3220952 , 2.8031015 ],
         [0.        , 0.1590641 , 0.        , ..., 0.        ,
          2.0702174 , 3.4394553 ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          1.5714682 , 2.23192   ],
         ...,
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.3797317 ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        ...,

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         ...,
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.6268955 ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.09846598]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.70790523, 0.28814578, 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.20785622, 0.        , ..., 0.        ,
          0.        , 0.        ],
         ...,
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         ...,
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]]]], dtype=float32)>
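The output shape (1, 30, 40, 1280) can be checked by hand. MobileNetV2 halves the spatial size five times (a total stride of 32), so 960 / 32 = 30 and 1280 / 32 = 40. Below is a small sketch of that arithmetic. It assumes each stride-2 stage rounds up the way 'same' padding does, and `feature_map_size` is a hypothetical helper, not part of Keras:

```python
import math

def feature_map_size(n, num_stride2_stages=5):
    """Spatial size after a backbone that halves the size num_stride2_stages times.

    Assumption: each stride-2 stage outputs ceil(n / 2), as 'same' padding does.
    """
    for _ in range(num_stride2_stages):
        n = math.ceil(n / 2)
    return n

# 960x1280 input -> 30x40 feature map (total stride 2**5 = 32)
print(feature_map_size(960), feature_map_size(1280))  # 30 40

# 128x128 input after 1..5 halvings -> the sizes noted next to layer_names
print([feature_map_size(128, k) for k in range(1, 6)])  # [64, 32, 16, 8, 4]
```

The same arithmetic explains the 64x64 ... 4x4 comments in the layer list used for the U-Net encoder.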
# Fix input_shape / a pre-trained model such as MobileNet still performs reasonably well when data is scarce
base_model = tf.keras.applications.MobileNetV2(input_shape=(128,128,3), include_top=False)

base_model(np.array(im)[np.newaxis]) # input_shape is now fixed at (128, 128, 3), so this 960x1280 image is not accepted as input
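With input_shape fixed, the 960x1280 image has to be resized before it can be fed in. Here is a minimal sketch. It assumes a random array stands in for `np.array(im)` and uses `weights=None` only to skip the weight download. `tf.image.resize` and `tf.keras.applications.mobilenet_v2.preprocess_input` are standard TensorFlow/Keras functions, and the latter scales pixels to [-1, 1]:

```python
import numpy as np
import tensorflow as tf

# Stand-in for np.array(im): a random 960x1280 RGB image (assumption; the post loads 'ade.jpg')
image = np.random.randint(0, 256, size=(960, 1280, 3), dtype=np.uint8)

# Resize to the fixed 128x128 input size (result is float32), then add a batch axis
resized = tf.image.resize(image, (128, 128))[tf.newaxis]

# Scale pixel values from [0, 255] to [-1, 1], the range MobileNetV2 expects
batch = tf.keras.applications.mobilenet_v2.preprocess_input(resized)

# weights=None only to avoid a download; the post uses the default ImageNet weights
model = tf.keras.applications.MobileNetV2(input_shape=(128, 128, 3), include_top=False, weights=None)
features = model(batch)
print(features.shape)  # (1, 4, 4, 1280)
```

Resizing distorts the 3:4 aspect ratio of the original image. For segmentation, the label mask would need the same resize, using nearest-neighbour interpolation so class ids are not blended.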
layer_names = [
    'block_1_expand_relu', # 64x64
    'block_3_expand_relu', # 32x32
    'block_6_expand_relu', # 16x16
    'block_13_expand_relu', # 8x8
    'block_16_project', # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]
base_model_outputs

[<KerasTensor: shape=(None, 64, 64, 96) dtype=float32 (created by layer 'block_1_expand_relu')>,
 <KerasTensor: shape=(None, 32, 32, 144) dtype=float32 (created by layer 'block_3_expand_relu')>,
 <KerasTensor: shape=(None, 16, 16, 192) dtype=float32 (created by layer 'block_6_expand_relu')>,
 <KerasTensor: shape=(None, 8, 8, 576) dtype=float32 (created by layer 'block_13_expand_relu')>,
 <KerasTensor: shape=(None, 4, 4, 320) dtype=float32 (created by layer 'block_16_project')>]
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False
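The down_stack above is only the encoder half of a U-Net. Below is a minimal sketch of a matching decoder, assuming 3 output classes. Each `upsample` block doubles the spatial size with Conv2DTranspose, and its output is concatenated with the encoder feature map of the same size, which forms the skip connection. `upsample`, `unet_model` and `OUTPUT_CLASSES` are hypothetical names for illustration, and `weights=None` is used only so the sketch runs without downloading weights:

```python
import tensorflow as tf

OUTPUT_CLASSES = 3  # hypothetical number of segmentation classes

# weights=None only avoids downloading the ImageNet weights in this sketch;
# the post itself uses the default pre-trained weights
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(128, 128, 3), include_top=False, weights=None)
layer_names = ['block_1_expand_relu', 'block_3_expand_relu', 'block_6_expand_relu',
               'block_13_expand_relu', 'block_16_project']
down_stack = tf.keras.Model(
    inputs=base_model.input,
    outputs=[base_model.get_layer(name).output for name in layer_names])
down_stack.trainable = False  # the encoder is used as a frozen feature extractor

def upsample(filters):
    """Hypothetical decoder block: doubles height and width (e.g. 4x4 -> 8x8)."""
    return tf.keras.Sequential([
        tf.keras.layers.Conv2DTranspose(filters, 3, strides=2, padding='same', use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
    ])

def unet_model(output_channels):
    inputs = tf.keras.Input(shape=(128, 128, 3))
    skips = down_stack(inputs)  # feature maps at 64x64, 32x32, 16x16, 8x8, 4x4
    x = skips[-1]               # 4x4 bottleneck
    up_stack = [upsample(512), upsample(256), upsample(128), upsample(64)]
    for up, skip in zip(up_stack, reversed(skips[:-1])):
        x = up(x)                                     # upsample to the skip's size
        x = tf.keras.layers.Concatenate()([x, skip])  # U-Net skip connection
    # Final 64x64 -> 128x128 upsampling with one output channel per class
    outputs = tf.keras.layers.Conv2DTranspose(output_channels, 3, strides=2, padding='same')(x)
    return tf.keras.Model(inputs, outputs)

model = unet_model(OUTPUT_CLASSES)
print(tuple(model.output.shape))  # (None, 128, 128, 3)
```

Concatenate is used for the skip connections rather than Add. The decoder and encoder maps have different channel counts (for example 512 vs 576 at 8x8), so Add would fail here, and concatenation keeps both sets of features.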

 

 

Concatenate vs Add

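- Concatenate joins tensors along an axis (in Keras, the last/channel axis by default), so the channel counts add up and both inputs are preserved unchanged
- Add sums the values element-wise; the inputs must have the same shape, and the output keeps that shape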
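The difference is easiest to see in the shapes. The sketch below uses the NumPy equivalents of the two operations on hypothetical 8x8x64 feature maps: `np.concatenate(..., axis=-1)` for Concatenate and `+` for Add.

```python
import numpy as np

# Two feature maps with the same spatial size: (batch, height, width, channels)
a = np.ones((1, 8, 8, 64))
b = np.full((1, 8, 8, 64), 2.0)

# Concatenate: stack along the channel axis, so channels add up (64 + 64)
concat = np.concatenate([a, b], axis=-1)
print(concat.shape)  # (1, 8, 8, 128)

# Add: element-wise sum, so the shape is unchanged and the two inputs merge into one value
added = a + b
print(added.shape, added[0, 0, 0, 0])  # (1, 8, 8, 64) 3.0
```

After concatenation, the original values of `a` and `b` can still be read back from the two channel halves. After addition, only their sum (3.0) remains.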

 

GAN(Generative Adversarial Network)

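GAN architecture

A GAN trains two models, a Generator and a Discriminator, at the same time in an adversarial process.
The generator G learns the distribution of the real data, and the discriminator D tries to tell whether a sample is real data or was produced by the generator.
G is trained to produce images good enough to fool D, that is, to raise the probability that D makes a mistake, while D is trained to raise the probability that it classifies real and generated samples correctly.

In short, the generator produces fake data that looks real, the discriminator compares it with real data and judges which one is real, and the two models improve by competing with each other.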
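For reference, the two-player game described above is usually written as the following minimax objective. Here $p_{\text{data}}$ is the real-data distribution and $p_z$ is the noise distribution fed to the generator:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

In the MNIST code, training D with binary cross-entropy on labels 1 (real) and 0 (fake) maximizes $V$ with respect to $D$. Training the combined model on generated images labelled 1 makes G minimize $-\log D(G(z))$ rather than $\log(1 - D(G(z)))$. This is the commonly used non-saturating variant, which gives the generator stronger gradients early in training.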

 

 

Types of GAN

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

X_train = X_train.reshape(-1,28*28) / 255
X_test = X_test.reshape(-1,28*28) / 255
input_ = tf.keras.Input((10,))
x = tf.keras.layers.Dense(128)(input_)
x = tf.keras.layers.LeakyReLU()(x)
x = tf.keras.layers.Dense(256)(x)
x = tf.keras.layers.LeakyReLU()(x)
x = tf.keras.layers.Dense(512)(x)
x = tf.keras.layers.LeakyReLU()(x)
x = tf.keras.layers.Dense(28*28, activation='sigmoid')(x)

generator = tf.keras.models.Model(input_, x)
input_ = tf.keras.Input((784,))
x = tf.keras.layers.Dense(1024)(input_)
x = tf.keras.layers.LeakyReLU()(x)
x = tf.keras.layers.Dense(512)(x)
x = tf.keras.layers.LeakyReLU()(x)
x = tf.keras.layers.Dense(256)(x)
x = tf.keras.layers.LeakyReLU()(x)
x = tf.keras.layers.Dense(1, activation='sigmoid')(x)

discriminator = tf.keras.models.Model(input_, x)
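# Compile D on its own first: discriminator.train_on_batch updates D's weights
discriminator.compile(loss='binary_crossentropy',optimizer='adam')

# Freeze D before building and compiling the combined model, so GAN.train_on_batch updates only G
discriminator.trainable = False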
gan_input = tf.keras.Input((10,))
x = generator(gan_input)
output = discriminator(x)

GAN = tf.keras.models.Model(gan_input, output)
GAN.compile(loss='binary_crossentropy',optimizer='adam')
GAN.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
model (Functional)           (None, 784)               568208    
_________________________________________________________________
model_1 (Functional)         (None, 1)                 1460225   
=================================================================
Total params: 2,028,433
Trainable params: 568,208
Non-trainable params: 1,460,225
_________________________________________________________________
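The parameter counts in this summary can be reproduced by hand. A Dense layer with n_in inputs and n_out units has n_in × n_out weights plus n_out biases, and LeakyReLU adds none. The small sketch below uses hypothetical helper names and confirms that the trainable count equals the generator's parameters, while the non-trainable count equals the frozen discriminator's:

```python
def dense_params(n_in, n_out):
    # A Dense layer has an n_in x n_out weight matrix plus n_out biases
    return n_in * n_out + n_out

def mlp_params(sizes):
    # Sum over consecutive layer pairs; LeakyReLU layers contribute no parameters
    return sum(dense_params(i, o) for i, o in zip(sizes, sizes[1:]))

generator_sizes = [10, 128, 256, 512, 784]
discriminator_sizes = [784, 1024, 512, 256, 1]

g = mlp_params(generator_sizes)      # trainable inside GAN
d = mlp_params(discriminator_sizes)  # frozen inside GAN
print(g, d, g + d)  # 568208 1460225 2028433
```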
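def get_batch(data, batch_size=32):
    # Split data into consecutive batches of batch_size rows;
    # a trailing partial batch is dropped
    batches = []
    for i in range(int(data.shape[0]//batch_size)):
        batch = data[i*batch_size: (i+1)*batch_size]
        batches.append(batch)
    return np.asarray(batches)  # shape: (num_batches, batch_size, 784)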

get_batch(X_train, 100).shape
#(600, 100, 784)
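d_losses = []
g_losses = []
for i in range(1,11):                # 10 epochs
    for j in get_batch(X_train):     # j: 32 real images, shape (32, 784)
        # 1) Discriminator step: 32 real images (label 1) + 32 generated images (label 0)
        input_noise = np.random.uniform(-1,1, size=[32,10])
        fake = generator.predict(input_noise, verbose=0)
        x_dis = np.concatenate([j, fake])
        y_dis = np.zeros(2*32)
        y_dis[:32] = 1

        discriminator.trainable = True   # unfreeze D for its own update
        d_loss = discriminator.train_on_batch(x_dis, y_dis) # fit
        discriminator.trainable = False  # freeze D again before the generator step

        # 2) Generator step: label the fakes as real (1) so G learns to fool the frozen D
        noise = np.random.uniform(-1,1,size=[32,10])
        y_gan = np.ones(32)
        g_loss = GAN.train_on_batch(noise, y_gan)

    # keep the loss of the last batch of each epoch
    d_losses.append(d_loss)
    g_losses.append(g_loss)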
fake = generator.predict(noise).reshape(-1, 28,28)

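plt.imshow(fake[0]) # after one training run

plt.imshow(fake[1]) # after one training run

plt.imshow(fake[0]) # after a second training run (running the loop again)

plt.imshow(fake[1]) # after a second training run (running the loop again)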

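train_on_batch

train_on_batch runs a single gradient update on exactly the batch it is given, instead of splitting the data into batches of a fixed batch_size the way fit does.
In a GAN the generator produces new fake images at every training step, so fresh data has to be passed in on every step rather than prepared once up front.
That is why train_on_batch is a good fit here.
train_on_batch is also handy when a pre-trained model has to be updated one batch at a time.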
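Below is a minimal sketch of that call pattern on a tiny hypothetical model. Each call performs one update on whatever batch it receives, even when the batch size changes between calls. `return_dict=True` is a standard `train_on_batch` argument that returns the logs as a dict:

```python
import numpy as np
import tensorflow as tf

# Tiny hypothetical model, only to show the call pattern
model = tf.keras.Sequential([tf.keras.Input((4,)), tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(loss='binary_crossentropy', optimizer='adam')

before = model.get_weights()[0].copy()
losses = []
for batch_size in (32, 7, 64):  # a different batch size on every call is fine
    # Fresh data on every step, like the new fake images in the GAN loop
    x = np.random.normal(size=(batch_size, 4)).astype('float32')
    y = (x[:, 0] > 0).astype('float32')
    # One gradient update on exactly this batch
    logs = model.train_on_batch(x, y, return_dict=True)
    losses.append(float(logs['loss']))
after = model.get_weights()[0]
print(losses)
```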

 

 
