Temp

[Streamlit] @st.cache를 사용하여 Torchscript load하기

ju_young 2021. 12. 17. 01:19
728x90

일반적으로 다음과 같이 @st.cache으로 사용할 경우 hash_funcs을 사용하라면 에러 메시지가 뜬다.

@st.cache
def load_model(weights=os.path.join(MODEL_DIR_PATH, 'best.torchscript.pt')):
    # Load model
    w = weights
    model = torch.jit.load(w)
    return model

Error Message

다음 소스코드를 보면 특정 type의 경우에는 return to_bytes를 한다. 아마 torchscript의 module은 streamlit에서 작성되지 않은 것 같다.

https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/legacy_caching/hashing.py#L602

 

이런 경우 다음과 같이 hash_funcs를 사용해서 해결할 수 있다. lambda_:None으로 값을 지정하면 해당 object의 해싱을 사용하지 않는다는 의미이다. 다른 방법으로는 Custom한 hash 함수를 값으로 지정할 수 있다고 한다.

@st.cache(hash_funcs={torch.jit._script.RecursiveScriptModule : lambda _: None})
def load_model(weights=os.path.join(MODEL_DIR_PATH, 'best.torchscript.pt')):
    # Load model
    w = weights
    model = torch.jit.load(w)
    return model

그 결과 caching을 했을때와 하지 않았을 때의 시간차를 다음과 같이 확인 할 수 있었다. (모델을 load하는 시간만 확인함)

[ref]

https://docs.streamlit.io/knowledge-base/using-streamlit/caching-issues

728x90