
Image(Object) Classification

Image classification flow

 

Cost reduction vs. performance improvement

┌─ Hand-crafted algorithms (rule-based)

└─ ML

      ┌─ Traditional ML

      └─ Deep Learning

            └─ Works well only when object size and position stay fixed (strong assumptions required)

                  ┌─ FNN (fully connected)

                  └─ CNN

                        └─ The larger the capacity, the better the performance

                              (number of kernels, filters, layers, perceptrons)

                              └─ Requires a very large number of images

                                    └─ Augmentation (inflating the data)

                                          └─ When the amount of data is still limited

                                                └─ Transfer learning

                                                      ┌─ Feature Extraction

                                                      └─ Fine Tuning

How image data is stored

- Directory -> the most commonly used approach

- HDF -> mainly for research

- LMDB
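For the Directory approach, flow_from_directory (used below) infers one class per subdirectory. The flower_photos dataset used in this post, for example, is laid out as:

flower_photos/

      ├─ daisy/

      ├─ dandelion/

      ├─ roses/

      ├─ sunflowers/

      └─ tulips/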

How to load the data

ImageDataGenerator + tf.data

ImageDataGenerator can handle loading, preprocessing, and augmentation through its options, but it is slow on its own, so mixing ImageDataGenerator with tf.data is much more efficient.

import tensorflow as tf 

idg = tf.keras.preprocessing.image.ImageDataGenerator() # its options conveniently handle loading and augmentation 
didg = idg.flow_from_directory('flower_photos/')
# Found 3670 images belonging to 5 classes.

next(didg)[0].dtype, next(didg)[0].shape
# (dtype('float32'), (32, 256, 256, 3))

next(didg)[1].dtype, next(didg)[1].shape
# (dtype('float32'), (32, 5))

# from_generator expects a callable as its argument, hence the lambda 
train = tf.data.Dataset.from_generator(lambda: didg, output_types=(tf.float32, tf.float32), output_shapes=((None,256,256,3),(None,5))) 

train
# <FlatMapDataset shapes: ((None, 256, 256, 3), (None, 5)), types: (tf.float32, tf.float32)>
for i in train.take(1):
  print(i)
  (<tf.Tensor: shape=(32, 256, 256, 3), dtype=float32, numpy=
array([[[[124., 118.,  94.],
         [125., 123., 102.],
         ...,
         [ 80., 110., 118.]]]], dtype=float32)>, <tf.Tensor: shape=(32, 5), dtype=float32, numpy=
array([[0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       ...,
       [1., 0., 0., 0., 0.]], dtype=float32)>)
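One practical benefit of wrapping the generator in tf.data is that pipeline optimizations can be chained on afterwards; a minimal sketch (AUTOTUNE lets TensorFlow pick the buffer size):

train = train.prefetch(tf.data.experimental.AUTOTUNE) # overlap data loading with training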

Preprocessing approaches

1. Inside the model

- Preprocessing runs together with training, so each training step takes more time.

2. Outside the model

- Preprocessing outside the model is cumbersome and more error-prone.

With the preprocessing layers (tf.keras.layers.experimental.preprocessing), the same step can be used flexibly either inside or outside the model.

=> Since a layer is a function, it can also be used with map.
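A minimal sketch of both placements, using a Rescaling layer as the preprocessing step (the layer choice here is just for illustration):

import tensorflow as tf

rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)

# outside the model: the layer is a callable, so it works with Dataset.map
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 32, 32, 3))).batch(4)
ds_outside = ds.map(rescale)

# inside the model: the same layer becomes the first layer of the graph
inputs = tf.keras.Input((32, 32, 3))
x = rescale(inputs)
outputs = tf.keras.layers.Conv2D(8, 3)(x)
model = tf.keras.Model(inputs, outputs)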

Lightweight models

m1 = tf.keras.applications.MobileNet()

m1.summary()
Model: "mobilenet_1.00_224"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
conv1 (Conv2D)               (None, 112, 112, 32)      864       
_________________________________________________________________
conv1_bn (BatchNormalization (None, 112, 112, 32)      128       
_________________________________________________________________
conv1_relu (ReLU)            (None, 112, 112, 32)      0         
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 32)      288       
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32)      128       
_________________________________________________________________
conv_dw_1_relu (ReLU)        (None, 112, 112, 32)      0         
_________________________________________________________________
conv_pw_1 (Conv2D)           (None, 112, 112, 64)      2048      
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 112, 112, 64)      256       
_________________________________________________________________
conv_pw_1_relu (ReLU)        (None, 112, 112, 64)      0         
_________________________________________________________________
conv_pad_2 (ZeroPadding2D)   (None, 113, 113, 64)      0         
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D)  (None, 56, 56, 64)        576       
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 56, 56, 64)        256       
_________________________________________________________________
conv_dw_2_relu (ReLU)        (None, 56, 56, 64)        0         
_________________________________________________________________
conv_pw_2 (Conv2D)           (None, 56, 56, 128)       8192      
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 56, 56, 128)       512       
_________________________________________________________________
conv_pw_2_relu (ReLU)        (None, 56, 56, 128)       0         
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D)  (None, 56, 56, 128)       1152      
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 56, 56, 128)       512       
_________________________________________________________________
conv_dw_3_relu (ReLU)        (None, 56, 56, 128)       0         
_________________________________________________________________
conv_pw_3 (Conv2D)           (None, 56, 56, 128)       16384     
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 56, 56, 128)       512       
_________________________________________________________________
conv_pw_3_relu (ReLU)        (None, 56, 56, 128)       0         
_________________________________________________________________
conv_pad_4 (ZeroPadding2D)   (None, 57, 57, 128)       0         
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D)  (None, 28, 28, 128)       1152      
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 28, 28, 128)       512       
_________________________________________________________________
conv_dw_4_relu (ReLU)        (None, 28, 28, 128)       0         
_________________________________________________________________
conv_pw_4 (Conv2D)           (None, 28, 28, 256)       32768     
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 28, 28, 256)       1024      
_________________________________________________________________
conv_pw_4_relu (ReLU)        (None, 28, 28, 256)       0         
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D)  (None, 28, 28, 256)       2304      
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 28, 28, 256)       1024      
_________________________________________________________________
conv_dw_5_relu (ReLU)        (None, 28, 28, 256)       0         
_________________________________________________________________
conv_pw_5 (Conv2D)           (None, 28, 28, 256)       65536     
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 28, 28, 256)       1024      
_________________________________________________________________
conv_pw_5_relu (ReLU)        (None, 28, 28, 256)       0         
_________________________________________________________________
conv_pad_6 (ZeroPadding2D)   (None, 29, 29, 256)       0         
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D)  (None, 14, 14, 256)       2304      
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 14, 14, 256)       1024      
_________________________________________________________________
conv_dw_6_relu (ReLU)        (None, 14, 14, 256)       0         
_________________________________________________________________
conv_pw_6 (Conv2D)           (None, 14, 14, 512)       131072    
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_pw_6_relu (ReLU)        (None, 14, 14, 512)       0         
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_dw_7_relu (ReLU)        (None, 14, 14, 512)       0         
_________________________________________________________________
conv_pw_7 (Conv2D)           (None, 14, 14, 512)       262144    
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_pw_7_relu (ReLU)        (None, 14, 14, 512)       0         
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_dw_8_relu (ReLU)        (None, 14, 14, 512)       0         
_________________________________________________________________
conv_pw_8 (Conv2D)           (None, 14, 14, 512)       262144    
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_pw_8_relu (ReLU)        (None, 14, 14, 512)       0         
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_dw_9_relu (ReLU)        (None, 14, 14, 512)       0         
_________________________________________________________________
conv_pw_9 (Conv2D)           (None, 14, 14, 512)       262144    
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_pw_9_relu (ReLU)        (None, 14, 14, 512)       0         
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 14, 14, 512)       4608      
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_dw_10_relu (ReLU)       (None, 14, 14, 512)       0         
_________________________________________________________________
conv_pw_10 (Conv2D)          (None, 14, 14, 512)       262144    
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_pw_10_relu (ReLU)       (None, 14, 14, 512)       0         
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 14, 14, 512)       4608      
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_dw_11_relu (ReLU)       (None, 14, 14, 512)       0         
_________________________________________________________________
conv_pw_11 (Conv2D)          (None, 14, 14, 512)       262144    
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 14, 14, 512)       2048      
_________________________________________________________________
conv_pw_11_relu (ReLU)       (None, 14, 14, 512)       0         
_________________________________________________________________
conv_pad_12 (ZeroPadding2D)  (None, 15, 15, 512)       0         
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 7, 7, 512)         4608      
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 7, 7, 512)         2048      
_________________________________________________________________
conv_dw_12_relu (ReLU)       (None, 7, 7, 512)         0         
_________________________________________________________________
conv_pw_12 (Conv2D)          (None, 7, 7, 1024)        524288    
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 7, 7, 1024)        4096      
_________________________________________________________________
conv_pw_12_relu (ReLU)       (None, 7, 7, 1024)        0         
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 7, 7, 1024)        9216      
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 7, 7, 1024)        4096      
_________________________________________________________________
conv_dw_13_relu (ReLU)       (None, 7, 7, 1024)        0         
_________________________________________________________________
conv_pw_13 (Conv2D)          (None, 7, 7, 1024)        1048576   
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024)        4096      
_________________________________________________________________
conv_pw_13_relu (ReLU)       (None, 7, 7, 1024)        0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 1024)        0         
_________________________________________________________________
dropout (Dropout)            (None, 1, 1, 1024)        0         
_________________________________________________________________
conv_preds (Conv2D)          (None, 1, 1, 1000)        1025000   
_________________________________________________________________
reshape_2 (Reshape)          (None, 1000)              0         
_________________________________________________________________
predictions (Activation)     (None, 1000)              0         
=================================================================
Total params: 4,253,864
Trainable params: 4,231,976
Non-trainable params: 21,888
_________________________________________________________________

 

tf.keras.utils.plot_model(m1,show_layer_names=True)

Five types of convolution

1. Convolution

2. Grouped convolution

- Splits the convolution operation across groups of channels to reduce computation.

3. 1x1 convolution

- A convolution that mixes channels element-wise, used to reduce the channel dimension and add extra non-linearity.

4. Depth-wise convolution

- Performs the convolution separately on each channel.

- Only values at the same depth are combined.

5. Depth-wise separable convolution - Depthwise convolution + 1x1 convolution

Depth-wise convolution

Depthwise convolution operates on each channel separately.

It was devised because standard convolution, which always sums over all channels, cannot extract spatial features for each individual channel.

Because each channel is processed on its own, fewer parameters are involved in each operation.

Depthwise convolution performs worse than ordinary convolution, but since its purpose is a lightweight model, some loss of accuracy has to be accepted;

the primary goal should be to shrink the model while preserving as much performance as possible.

A drawback is that the depth is fixed: the output must have the same number of channels as the input.

Depth-wise separable convolution

Depthwise convolution + 1x1 convolution: the depthwise step filters each channel spatially, and the pointwise (1x1) step mixes the channels (see the parameter-count sketch below).
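A minimal sketch of the savings, comparing parameter counts on an arbitrary 64-channel input (the layer sizes are assumptions for illustration):

import tensorflow as tf

inp = tf.keras.Input((28, 28, 64))

# standard 3x3 convolution: 3*3*64*128 + 128 bias = 73,856 params
std = tf.keras.layers.Conv2D(128, 3, padding='same')(inp)

# depthwise 3x3 (per channel): 3*3*64 + 64 bias = 640 params
dw = tf.keras.layers.DepthwiseConv2D(3, padding='same')(inp)
# pointwise 1x1 channel mixing: 1*1*64*128 + 128 bias = 8,320 params
pw = tf.keras.layers.Conv2D(128, 1)(dw)

# separable total: 8,960 params vs 73,856, roughly 8x fewer
tf.keras.Model(inp, [std, pw]).summary()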

 

m2 = tf.keras.applications.MobileNetV2() # follows the ResNet approach (residual/skip connections)

tf.keras.utils.plot_model(m2 ,show_layer_names=True)

x = tf.keras.layers.BatchNormalization()
x.build((None,2))
x(tf.constant([[1. ,2.]]))

# <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.99950033, 1.9990007 ]], dtype=float32)>
x.weights # gamma and beta are found through training 
# [<tf.Variable 'gamma:0' shape=(2,) dtype=float32, numpy=array([1., 1.], dtype=float32)>,
#  <tf.Variable 'beta:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>,
#  <tf.Variable 'moving_mean:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>,
#  <tf.Variable 'moving_variance:0' shape=(2,) dtype=float32, numpy=array([1., 1.], dtype=float32)>]
x.trainable_weights
# [<tf.Variable 'gamma:0' shape=(2,) dtype=float32, numpy=array([1., 1.], dtype=float32)>,
#  <tf.Variable 'beta:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]
x.non_trainable_weights # updated as running statistics, not by backprop 
# [<tf.Variable 'moving_mean:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>,
#  <tf.Variable 'moving_variance:0' shape=(2,) dtype=float32, numpy=array([1., 1.], dtype=float32)>]
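The printed tensor follows directly from the inference-mode formula y = gamma * (x - moving_mean) / sqrt(moving_variance + epsilon) + beta. With the fresh weights above (gamma=1, beta=0, moving_mean=0, moving_variance=1) and Keras's default epsilon=0.001, each input is simply scaled by 1/sqrt(1.001) ≈ 0.9995, which matches the output exactly.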
input_ = tf.keras.Input((28,28,1))
x = tf.keras.layers.MaxPool2D(2,2)(input_)
model = tf.keras.models.Model(input_,x)

model.summary() # pooling has no parameters because it only computes / it plays no part in learning 
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 1)         0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
data_augmentation = tf.keras.Sequential(
    [
        tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal"),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()

prediction_layer = tf.keras.layers.Dense(1)

inputs = tf.keras.Input(shape=(160,160,3)) # the batch dimension is left implicit 
x = data_augmentation(inputs) # apply augmentation 
x = preprocess_input(x)  # preprocessing 
x = base_model(x, training=False) # training=False keeps BatchNormalization in inference mode, so it uses its moving statistics rather than per-batch statistics 
x = global_average_layer(x) 
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)   # only prediction_layer should learn via backpropagation (this requires freezing the base with base_model.trainable = False); the earlier layers then just compute 
model = tf.keras.Model(inputs, outputs)
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
              

model.summary()
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_7 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential_2 (Sequential)    (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv_2 (TFOpLambd (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract_2 (TFOpLamb (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d_2 ( (None, 1280)              0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 2,225,153
Non-trainable params: 34,112
_________________________________________________________________
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

BATCH_SIZE = 32
IMG_SIZE = (160, 160)
initial_epochs = 10

train_dataset = image_dataset_from_directory('cats_and_dogs_filtered/train',
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)

validation_dataset = image_dataset_from_directory('cats_and_dogs_filtered/validation',
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
                                                  
# Found 2000 files belonging to 2 classes.
# Found 1000 files belonging to 2 classes.
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
# (32, 5, 5, 1280)
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)

Epoch 1/10
63/63 [==============================] - 252s 4s/step - loss: 0.3724 - accuracy: 0.8405 - val_loss: 0.1214 - val_accuracy: 0.9730
Epoch 2/10
63/63 [==============================] - 162s 3s/step - loss: 0.1438 - accuracy: 0.9430 - val_loss: 0.0879 - val_accuracy: 0.9520
Epoch 3/10
63/63 [==============================] - 161s 3s/step - loss: 0.1267 - accuracy: 0.9525 - val_loss: 0.0871 - val_accuracy: 0.9580
Epoch 4/10
63/63 [==============================] - 161s 3s/step - loss: 0.0929 - accuracy: 0.9590 - val_loss: 0.0485 - val_accuracy: 0.9770
Epoch 5/10
63/63 [==============================] - 162s 3s/step - loss: 0.0856 - accuracy: 0.9660 - val_loss: 0.3747 - val_accuracy: 0.9060
Epoch 6/10
63/63 [==============================] - 160s 3s/step - loss: 0.0874 - accuracy: 0.9655 - val_loss: 0.0987 - val_accuracy: 0.9780
Epoch 7/10
63/63 [==============================] - 160s 3s/step - loss: 0.0817 - accuracy: 0.9755 - val_loss: 0.0573 - val_accuracy: 0.9760
Epoch 8/10
63/63 [==============================] - 160s 3s/step - loss: 0.0556 - accuracy: 0.9830 - val_loss: 0.0556 - val_accuracy: 0.9810
Epoch 9/10
63/63 [==============================] - 160s 3s/step - loss: 0.0603 - accuracy: 0.9765 - val_loss: 0.0524 - val_accuracy: 0.9810
Epoch 10/10
63/63 [==============================] - 160s 3s/step - loss: 0.0663 - accuracy: 0.9815 - val_loss: 0.0380 - val_accuracy: 0.9840
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

In TensorFlow, when a validation dataset is included during training,

dropout and batch normalization are skipped at evaluation time, so the validation loss often comes out lower than the training loss.

(A case like this is not necessarily an underfitting problem.)
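The mechanism is easy to see with a Dropout layer on its own; a minimal sketch (which units get zeroed is random):

import tensorflow as tf

drop = tf.keras.layers.Dropout(0.5)
x = tf.ones((1, 4))

drop(x, training=True)   # training path: ~half the units zeroed, the rest scaled by 1/(1-0.5)
# e.g. array([[2., 0., 2., 0.]], dtype=float32)
drop(x, training=False)  # evaluation path: identity, every unit passes through
# array([[1., 1., 1., 1.]], dtype=float32)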

Advanced fine-tuning techniques

Fine tuning

Adapting the architecture of an already-trained model to a new task and continuing training from the already-learned weights.

1. Start from a model whose performance is already reasonably good.

2. Freeze part of it (trainable=False) and retrain the rest with a slow learning rate.

3. The new training data must not differ too much from the original training data.

4. Do not re-initialize the layers to be retrained; that destroys the existing ability, so keep the proven weights as they are.

Un-freezing the top layers

# fine tuning starts from an already-trained model 

base_model.trainable = True # first make everything trainable 

fine_tune_at = 100

for layer in base_model.layers[:fine_tune_at]: # then freeze every layer below fine_tune_at 
  layer.trainable = False                      # the top layers stay trainable: they hold the specialized features feeding the result (dog/cat classification)

A simple recipe for fine tuning (see the sketch after this list)

1. Retrain from the last layers whenever possible (the last layers classify the most specialized features).

2. Lower the learning rate.

3. Reduce the number of epochs.
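A minimal sketch combining these points, continuing from the un-freezing code above (the 10x learning-rate reduction follows the TensorFlow transfer-learning tutorial; fine_tune_epochs is an illustrative name):

# recompile with a 10x smaller learning rate after un-freezing
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

# train a few extra epochs, resuming from where feature extraction stopped
fine_tune_epochs = 10
history_fine = model.fit(train_dataset,
                         epochs=initial_epochs + fine_tune_epochs,
                         initial_epoch=initial_epochs,
                         validation_data=validation_dataset)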

 

Catastrophic forgetting

- When a model is trained on a different kind of data, its performance on the data it learned before drops sharply.

- A large amount of information about the previous dataset is lost even when the old and new training datasets are related.

Semantic shift

- When new kinds of data are additionally trained into an existing model, the weights shift and the meanings the model had captured change.

Incremental learning

- Training on data with entirely different characteristics so that an existing model grows into one that can also classify that new kind of data.

 

 
