Machine Learning/Coding

[실습] k-최근접 이웃(k-Nearest Neighbor, kNN)

ju_young 2021. 5. 22. 01:51
728x90

[https://github.com/wikibook/machine-learning]에서 다운로드 받은 농구선수에 대한 데이터를 사용하여 kNN 알고리즘을 적용해보는 실습을 하였다.

 

목표는 임의의 농구선수의 포지션을 예측하는 것이다.

데이터 획득

import pandas as pd

df = pd.read_csv(./data/basketball_stat.csv)
df.head()

colab에서 실행하였고 파일을 드라이브에 넣었기 때문에 경로를 "./data/basketball_stat.csv"로 지정해준 것이다.

 

[출력]

각 데이터 속성 값의 의미는 다음과 같다.

  • Player = 선수 이름
  • Pos = 포지션
  • 3P = 한 경기 평균 3점슛 성공 횟수
  • 2P = 한 경기 평균 2점슛 성공 횟수
  • TRB = 한 경기 평균 리바운드 성공 횟수
  • AST = 한 경기 평균 어시스트 성공 횟수
  • STL = 한 경기 평균 스틸 성공 횟수
  • BLK = 한 경기 평균 블로킹 성공 횟수

# 포지션의 개수 확인
df.Pos.value_counts()

[출력]

SG    50
C     50
Name: Pos, dtype: int64

 

데이터 시각화

import matplotlib.pyplot as plt
import seaborn as sns

#스틸, 2점슛 데이터 시각화
sns.lmplot('STL', '2P', data = df, fit_reg = False, # x축, y축, 데이터, 노 라인
           scatter_kws={"s":150}, # 좌표 상의 점 크기
           markers=["o", "x"], #모양 지정
           hue="Pos") #예측값

#타이틀
plt.title('STL and 2P in 2d plane')

#어시스트, 2점슛 데이터 시각화
sns.lmplot('AST', '2P', data = df, fit_reg = False, # x축, y축, 데이터, 노 라인
           scatter_kws={"s":150}, # 좌표 상의 점 크기
           markers=["o", "x"], #모양 지정
           hue="Pos") #예측값

#타이틀
plt.title('AST and 2P in 2d plane')

[출력]

스틸, 2점슛과 어시스트, 2점슛 데이터를 시각화 했을 때 슈팅가드와 센터의 경계가 너무 근접해서 분류하기 힘들다는 점을 쉽게 판단할 수 있다.


#블로킹, 3점슛 데이터 시각화
sns.lmplot('BLK', '3P', data = df, fit_reg = False, # x축, y축, 데이터, 노 라인
           scatter_kws={"s":150}, # 좌표 상의 점 크기
           markers=["o", "x"], #모양 지정
           hue="Pos") #예측값

#타이틀
plt.title('BLK and 3P in 2d plane')

#리바운드, 3점슛 데이터 시각화
sns.lmplot('TRB', '3P', data = df, fit_reg = False, # x축, y축, 데이터, 노 라인
           scatter_kws={"s":150}, # 좌표 상의 점 크기
           markers=["o", "x"], #모양 지정
           hue="Pos") #예측값

#타이틀
plt.title('TRB and 3P in 2d plane')

[출력]

블로킹과 3점슛, 리바운드와 3점슛으로 시각화하면 데이터의 구분이 확실한 것을 확인할 수 있다.

 

데이터 다음기

위에서 확인했듯이 분별력이 없는 2점슛과 어시스트, 스틸 속성을 제거해버린다.

#2점슛, 어시스트, 스틸 삭제
df.drop(['2P', 'AST', 'STL'], axis=1, inplace=True)

df.head()

[출력]

 

데이터 나누기

학습 데이터와 테스트 데이터로 분리한다.

from sklearn.model_selection import train_test_split

#80%는 학습 데이터, 20%는 테스트 데이터로 분리
train, test = train_test_split(df, test_size=0.2)

#학습 데이터, 테스트 데이터 개수 확인
print(train.shape[0])
print(test.shape[0])

[출력]

80
20

 

최적의 kNN 파라미터 찾기

cross_val_score를 사용하여 교차검증을 실행한다. 이때 k = 3부터 학습 데이터 절반의 크기까지 수행한다.

#kNN 라이브러리 import
from sklearn.neighbors import  KNeighborsClassifier
from sklearn.model_selection import cross_val_score # k-fold 교차검증

#최적의 k를 찾기 위해 교차 검증을 수행할 k의 범위를 3부터 학습 데이터 절반까지 지정
max_k_range = train.shape[0] // 2
k_list = []
for i in range(3, max_k_range, 2): #홀수
  k_list.append(i)

cross_validation_scores = [] # 각 k의 검증 결과 점수들
x_train = train[['3P', 'BLK', 'TRB']]
y_train = train[['Pos']]

#교차 검증(10-fold)을 각 k를 대상으로 수행해 검증 결과를 저장
for k in k_list:
  knn = KNeighborsClassifier(n_neighbors=k)
  scores = cross_val_score(knn, x_train, y_train.values.ravel(), cv=10,
                           scoring='accuracy')
  cross_validation_scores.append(scores.mean())

cross_validation_scores

[출력]

[0.95,
 0.95,
 0.95,
 0.9375,
 0.925,
 0.925,
 0.925,
 0.925,
 0.9125,
 0.9125,
 0.8875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.8625,
 0.8625,
 0.8625,
 0.8625]

 #k에 따른 정확도 시각화
plt.plot(k_list, cross_validation_scores)
plt.xlabel('the number of k')
plt.ylabel('Accuracy')
plt.show()

[출력]


#가장 예측율이 높은 k를 선정
k = k_list[cross_validation_scores.index(max(cross_validation_scores))]
print("The best number of k : " + str(k))

[출력]

The best number of k : 3

 

모델 테스트

from sklearn.metrics import accuracy_score

knn = KNeighborsClassifier(n_neighbors=k)

x_train = train[['3P', 'BLK', 'TRB']]
y_train = train[['Pos']]

#knn 모델 학습
knn.fit(x_train, y_train.values.ravel())

x_test = test[['3P', 'BLK', 'TRB']]
y_test = test[['Pos']]

#테스트 시작
pred = knn.predict(x_test)

#모델 예측 정확도(accuracy) 출력
print("accuracy : " + str(accuracy_score(y_test.values.ravel(), pred)))

[출력]

accuracy : 0.85

#실제값과 예측값 비교
comparison = pd.DataFrame({'prediction':pred, 'ground_truth':y_test.values.ravel()})
comparison

[출력]

728x90