Deep Learning

[Optimization] APEX ASP (Automatic SParsity)

ju_young 2022. 7. 30. 22:47
728x90

ASP는 모델의 학습, 추론 속도를 높이고 메모리 효율과 accuracy를 유지하는 것이 목적, 이를 위해 weight를 saprse하게 연산할 수 있게 하는 mask를 구하는 방식이다.

 

NVIDIA ampere gpu 에서는 2:4 fine grained sparsity가 지원되어있다. 이 방식으로 pruning을 지원하다.

 

2:4 fine grained sparsity는 아래 그림을 확인하면 간단하게 이해할 수 있다.

2:4 fine grained sparsity

2:4 fine grained sparsity는 위 그림과 같이 1D 또는 2D (2D도 가능하다)에서 작은 값 2개를 탈락시키는 방식으로 동작한다. 이렇게 생선된 마스크를 사용해서 weight는 이렇게 pruned 되어있는 살이있는 value들만 모아서 아래 그림의 오른쪽 Sparse MxNxK GEMM에 있는 Non-zero data vlaues 처럼 50%만 연산을 수행한다.

그리고 zero value를 표시하는 마스크(fig.2의 Sparse MxNxK GEMM의 2-bits indices)는 2bits value mask인데 이것을 weight와 함께 집어넣고 input에 마스킹을 하는 방식으로 반절의 데이터만 연산에 참여할 수 있게한다.

 

이렇게 zero value mask를 만들고 training이나 inference를 수행할때 50%의 연산만으로 많은 이득을 낼 수 있다.

 

ASP 적용은 다음과 같이 상당히 간단하다.

ASP.prune_trained_model(model, optimizer)

또한 prune_trained_model 함수의 내부는 다음과 같이 구현되어 있다.

ASP.init_model_for_pruning(model,
                mask_calculator="m4m2_1d",
                verbosity=2,
                whitelist=[torch.nn.Linear, torch.nn.Conv2d],
                allow_recompute_mask=False)

ASP.init_optimizer_for_pruning(optimizer)

ASP.compute_sparse_masks()

init_model_for_puning 함수에서는 모델의 initialize를 진행해준다. 여기서 mask_calculator라는 부분이 있는데 여기에 세팅하는 variable 들이 있는데 사실 asp 함수를 사용하게 되면 고정되어있다. 이거에 대한 의미를 말하면 "m4n2"는 2:4 fine grained sparsity mask를 구해라 라는 옵션이고 m4n2_2d도 사용할 수 있다. 하지만 2D보다 1D로 적용하는 것을 추천한다고 한다.

 

모델이 굉장히 작다면 2:4 sparse pruning을 반영되기 힘들고 학습이 잘 안될 수 있을 것이다. 이것을 해결하기 위해서 permutation으로 magnitude를 최대한 높은 것들을 살릴 수 있게끔한 후 sparsity를 적용하고 학습하면 수렴이되는 것을 발견했다고 한다. 사실 이 permutation 부분은 기본적으로 true가 되어있기 때문에 신경쓸 필요가 없다.

 

이제 ASP를 적용하여 학습을 진행하는 과정을 간단히 보면 다음과 같다.

- Step 1: Dense Training

- Step 2: Puning(ASP) 적용

- Step 3: Sparse Retraining

 

Step 3는 간단히 생각하면 Dense Training을 했을 때와 비슷한 성능이 나올 때까지 학습을 진행하면 되는 것이라고 이해하면 될 것 같다.

 

ASP 적용 후 왜 Train을 또 할까라는 의문이 든다면 아래 그림을 한 번 보자.

가운데 그림이 ASP를 적용했을 때의 weight 값들이고 아래 그림이 Retrain을 했을 때의 weight 값들이다. 두 그림의 차이를 보면 Retrain을 하면 중요한 것들은 더 weight가 더해진 것을 확인할 수 있다.

 

ASP를 실제로 사용해본 결과 약간의 삽질이 필요하다. 간단한 architecture를 가지는 모델은 적용이 쉽지만 Transformer 계열의 모델들에는 적용이 쉽지않다. 구체적으로 말하면 일반적으로 사용하는 pytorch 코드, 함수 사용에 잘 맞지 않는다. 마치 Torchscript Conversion할 때 오류가 상당히 많이 나는 느낌과 비슷하다.

 

개인적으로 Torchscript Conversion일 잘 되는 모델 코드는 ASP 적용도 무난하게 잘 되는 것 같다. 그리고 특정 cls가 key 값으로 없다는 에러가 뜬다면 ass_sparse_attributes 함수에 "sparse_parameters = ['weight']"로 수정하면 해결되었다. 왜냐면 어차피 sparse_parameter_list의 value 값들이 전부 ['weight']라고 되어있기 때문이다.

728x90