본문 바로가기
SW 개발/Data 분석 (RDB, NoSQL, Dataframe)

Keras 기계 학습 모델의 저장과 로드 방법 (Sample code)

by Kibua20 2021. 9. 12.

기계 학습 모델의 학습 모델을 저장과 로드하는 방법입니다. 데이터가 큰 기계 학습인 경우 학습 시간이 오래 걸리고, 학습 모델과 예측의 pipeline을 분리하기 위해서 학습의 모델을 저장하고 로드하는 기능이 필요합니다.   Keras의 모델의 저장과 로드는 https://www.tensorflow.org/guide/keras/save_and_serialize에서 자세히 설명되어 있고, 본 블로그도 해당 페이지 내용과 확인 결과를 기초로 작성했습니다.

 

Keras 모델은 아래와 같은 요소를 저장하고 있습니다.  Keras에서는 모델 저장 API를 사용하면  모델 정보 전체를 저장 또는 가중치 값을 저장합니다.

  • 학습 모델: 모델에 포함된 layer 구성 및 연결 방법
  • 모델의 상태: 가중치 값의 집합
  • 모델 컴파일 상태: optimizer 상태
  • 모델의 loss 및 metric: 모델을 컴파일하거나 add_loss() 또는 add_metric()을 호출하여 정의된 손실 및 메트릭의 집합

Keras에서 model saved와 Load 명령어는 아래와 같이 model.save()와 load_model() 사용할 수 있습니다. 

 

from tensorflow import keras

 

model = keras.Model(..., ...)

model.save("my_model")

reconstructed_model = keras.models.load_model("my_model")

 

 

1.  전체 학습 모델 저장 및 로딩

Keras에서는 전체 모델을 저장할 수 있습니다. 전체 모델을 디스크에 저장하는 데 사용할 수 있는 두 형식은 TensorFlow SavedModel 형식 이전 Keras H5 형식입니다. TensorFlow SavedModel 이 권장하는 형식이면 API의 기본 값입니다. 

 

model.save()에 의해서 저장하는 내용은 다음과 같습니다. saved_model.pb파일에는 모델 구성 및 Optimizer, loss, metric 가 저장되고, variables 폴더에는 가중치가 저장됩니다.

전체 학습 모델 저장 내용

#!/usr/bin/python3
# -*- coding: utf-8 -*-

import numpy as np

from tensorflow import keras

# --------------------------------------------------------------------------------------------------------
# Save keras model: code sample from https://www.tensorflow.org/guide/keras/save_and_serialize?hl=ko
def save_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mean_squared_error")
    
    # Train the model.
    test_input = np.random.random((128, 32))
    test_target = np.random.random((128, 1))

    model.fit(test_input, test_target)

    # Calling `save('my_model')` creates a SavedModel folder `my_model`.
    model.save("my_model")

# --------------------------------------------------------------------------------------------------------
# run the app.
def load_model():
    # It can be used to reconstruct the model identically.
    reconstructed_model = keras.models.load_model("my_model")

    # Train the model.
    test_input = np.random.random((128, 32))
    test_target = np.random.random((128, 1))

    # The reconstructed model is already compiled and has retained the optimizerstate, so training can resume:
    reconstructed_model.fit(test_input, test_target)

    print ('Reconstruced Model:')
    reconstructed_model.summary()

# --------------------------------------------------------------------------------------------------------
# run the app.
if __name__ == "__main__":
    save_model()
    load_model()

 

model.save() 호출 시 저장 파일 이름에서 확장자를 h5로 지정하면 Keras H5형식으로 저장됩니다.  Keras H5 형식으로 저장하면 my_model  폴더 대신 하나의  my_model.h5파일로 저장됩니다. SavedModel형식 대비해서 model.add_loss()와  model.add_metric()에 의해서 추가된 손실 및 메트릭은 저장되지 않습니다. 

 

# Keras H5 형식 저장

model.save("my_model.h5")

# Keras H5 형식으로 로드

reconstructed_model = keras.models.load_model("my_model.h5")

 

전체 모델을 저장하는 방법뿐 아니라, 아키텩쳐만 저장, 가중치 저장에 대한 API 도 지원하며, 상세 내용은 링크를 참조해주세요.

 

Sample code

Sample code는 GitHub에도 올렸으면 13_keras_model_save.py 파일을 참고해주세요. Model 저장과 로드 뿐 아니라  다양한 test code가 반영되어 있습니다.

 

관련 글:

[SW 개발/Data 분석 (RDB, NoSQL, Dataframe)] - k-mean Clustering 알고리즘 개념 및 Sample code

[SW 개발/Data 분석 (RDB, NoSQL, Dataframe)] - Keras를 이용한 다중 클래스 분류: softmax regression (Sample code)

[SW 개발/Data 분석 (RDB, NoSQL, Dataframe)] - Python plotly와 dash를 이용한 Web 기반 data visualization (sample code)

[SW 개발/Data 분석 (RDB, NoSQL, Dataframe)] - Python Dash를 활용한 Web App 구현 및 시계열 데이터 Visualization (Sample code)

[SW 개발/Data 분석 (RDB, NoSQL, Dataframe)] - Python Keras를 이용한 로직스틱 회귀 분석(logistics regression) 예제- Wine Quality 분석(Sample code)

[SW 개발/Data 분석 (RDB, NoSQL, Dataframe)] - Python Keras를 이용한 다중회귀(Multiple regression) 예측 (Sample code)

[SW 개발/Data 분석 (RDB, NoSQL, Dataframe)] - Python Keras를 이용한 Linear regression 예측 (Sample code)

[SW 개발/Python] - Python Decorator를 이용한 함수 실행 시간 측정 방법 (Sample code)

[SW 개발/Data 분석 (RDB, NoSQL, Dataframe)] - Random Number를 가지는 Pandas Dataframe 생성 (좋은 code와 나쁜 code 비교)

 




댓글