Temp
[Streamlit] @st.cache를 사용하여 Torchscript load하기
ju_young
2021. 12. 17. 01:19
728x90
일반적으로 다음과 같이 @st.cache으로 사용할 경우 hash_funcs을 사용하라면 에러 메시지가 뜬다.
@st.cache
def load_model(weights=os.path.join(MODEL_DIR_PATH, 'best.torchscript.pt')):
# Load model
w = weights
model = torch.jit.load(w)
return model
다음 소스코드를 보면 특정 type의 경우에는 return to_bytes를 한다. 아마 torchscript의 module은 streamlit에서 작성되지 않은 것 같다.
https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/legacy_caching/hashing.py#L602
이런 경우 다음과 같이 hash_funcs를 사용해서 해결할 수 있다. lambda_:None으로 값을 지정하면 해당 object의 해싱을 사용하지 않는다는 의미이다. 다른 방법으로는 Custom한 hash 함수를 값으로 지정할 수 있다고 한다.
@st.cache(hash_funcs={torch.jit._script.RecursiveScriptModule : lambda _: None})
def load_model(weights=os.path.join(MODEL_DIR_PATH, 'best.torchscript.pt')):
# Load model
w = weights
model = torch.jit.load(w)
return model
그 결과 caching을 했을때와 하지 않았을 때의 시간차를 다음과 같이 확인 할 수 있었다. (모델을 load하는 시간만 확인함)
[ref]
https://docs.streamlit.io/knowledge-base/using-streamlit/caching-issues
728x90