동삼이의 노트북
MNIST 분류 (1) - 분류 성능 평가 본문
MNIST 데이터베이스는 (Modified National Institute of Standards and Technology database) 손으로 쓴 70000개의 작은 숫자 이미지로 이루어진 대형 데이터베이스이다. 이 데이터셋은 학습용으로 아주 많이 사용되며 이번 프로젝트에서는 이 MNIST 데이터셋을 이용하여 분류 모델을 만들어 보도록 하자.
사이킷런에는 MNIST 데이터셋을 내려받을 수 있는 헬퍼 함수들이 존재한다. 보통 사이킷런에서 읽어 들인 데이터셋들은 일반적으로 비슷한 딕셔너리 구조를 갖고 있다.
- 데이터셋을 설명하는 DESCR키
- 샘플이 하나의 행, 특성이 하나의 열로 구성된 배열을 가진 data키
- 레이블 배열을 담은 target키
MNIST 데이터 셋은 7만개의 이미지가 존재하며 각 이미지에는 784개의 특성이 존재한다. 이미지가 28 x 28 픽셀이기 때문이다. 각 특성은 단순히 0(흰색) 부터 255 (검은색) 까지의 픽셀 강도를 나타낸다.
0번째 행을 추출하여 28 * 28 배열로 크기를 바꾼 후, imshow() 함수를 이용하여 픽셀 이미지를 나타낸다. 이미지 결과 5 로 보이는 문자 이미지가 출력이 되었다.
같은 행의 레이블 값을 확인해본 결과, 해당 이미지는 5 인것을 알 수있다. MNIST는 수기로 그린 작은 숫자들의 이미지 데이터셋이기 때문에 문자열 형태인 레이블 값들을 전부 정수형으로 바꿔준다.
MNIST에서 추출한 숫자 이미지는 위와 같다. 모두 수기로 작성한 숫자이기 때문에 같은 숫자라 할지라도 모습이 조금씩 다른 것을 알 수 있다.
본격적인 분류 작업에 들어가기 앞서서 테스트 셋을 만들어 따로 떼주도록 하자
이진 분류기 훈련
먼저 문제를 단순화해서 하나의 숫자, 5에 관해서만 식별해보도록 한다. '5-감지기'를 만들어내어 '5'와 '5가 아님' 두개의 클래스를 구분할 수있는 이진분류기를 만들어 보자.
레이블이 5면 True, 5가 아니면 False인 타겟 벡터를 생성했다. 이제 사이킷 런의 SGDClassifier 클래스를 사용하여 확률적 경사 하강법(Stochastic Gradient Descent) 분류기로 분류를 해보자. 이 분류기는 매우 큰 데이터셋을 효율적으로 처리하는 장점이 있다.
SGDClassifier를 이용하여 MNIST 이미지 데이터가 레이블이 5인지 5가 아닌지를 학습시켰다.
그 후 위에서 확인했었던 숫자 5의 이미지 행렬을 갖고 있던 데이터를 입력시켜 분류를 진행해본 결과 True라는 값을 내놓았다. 그렇다면 이 모델의 전체적인 성능은 어떨까?
성능 측정
cross_val_score() 함수를 이용하여 폴드가 3개인 k-hold 교차검증을 사용하여 위 모델을 평가해보았다. 각 폴드에 대한 정확도(accuracy)가 전부 95%이상이다. 과연 이 정확도가 '정확히' 위 모델의 정확도를 나타내는 지표일까?
어떤 입력값이 들어오든 False 즉, 5가 아님을 예측하는 임의의 모델을 만들어 본 후 성능평가를 실시해보도록 하자.
아이러니 하게도 모델은 모든 예측을 5가 아님으로 했음에도 정확도(accuracy)는 여전히 90%를 넘는 높은 성능을 보여준다. 위와 같은 점수가 나온 이유는 전체 데이터셋에서 10%정도만 숫자 5이기 때문에 무조건 '5가 아님' 으로 예측하면 정답을 맞출 확률이 90%이기 때문이다. 위 예시는 정확도(accuracy)를 분류기의 성능 측정 지표로 선호하지 않는 이유를 보여준다. 특히 불균형한 데이터셋일 경우에 더욱 그렇다.
Confusion Matrix
분류기의 성능을 평가하는 데에는 정확도 보다 Confusion Matrix(오차 행렬)를 조사하는 것이 더욱 좋다. 기본적인 아이디어는 클래스 A의 샘플이 클래스 B로 분류되는 횟수를 세는 것이다.
간단히 말해, True를 True로 예측하느냐, True를 False로 예측하느냐, False를 True로 예측하느냐, False를 False로 예측하느냐 로 행렬을 만들어서 성능을 평가하는 것이다. 위의 '5-감지기' 모델의 경우 True를 True로 예측하는 것에 대한 확률만 구했기 때문에 전부 '5가 아님'으로 답을 내놓아도 90% 이상의 정확도를 나타내었다. 하지만 위 네가지 경우를 모두 확인해본다면 성능은 크게 달라질 것이다.
위 코드에서 사용한 cross_val_predict는 k-hold 교차 검증을 수행하지만 cross_val_score와는 달리 점수를 반환하지 않고 각 테스트 폴드에서 얻은 예측 값을 반환한다. 이제 이를 통해 confusion_matrix() 함수를 이용하여 confusion matrix를 구해보자.
Confusion matrix의 행은 실제 클래스를 나타내고, 열은 예측한 클래스를 나타낸다. 이 행렬의 첫 번째 행은 '5가 아님' 에 대한 것으로 (negative class) 53,892개를 '5가 아님' 으로 정확하게 분류했고 이를 True negative라고 한다. 나머지 687개는 '5'라고 잘못 분류했다. 이를 False positive라고 한다. 두 번째 행은 '5'에 대한 것으로(positive class) 1891개를 '5가 아님'으로 잘못 분류했고(False negative) 3530개를 '5'라고 정확하게 분류했다.(True positive) 만약 완벽한 분류기라면, Confusion matrix의 왼쪽 위~오른쪽 아래 대각선 값이 0이 될 것이다.
이런 컨퓨전 매트릭스는 많은 정보를 주지만 좀 더 요약된 지표가 필요하다면 정밀도(precision)를 사용할 수도 있다.
- 정밀도 = TP / (TP + FP)
이 정밀도의 경우 분류기가 확실한 양성 샘플 하나만 예측한다면 완벽한 정밀도 1 을 얻을 수 있지만 다른 모든 양성 샘플을 무시하기 때문에 그렇게 유용하진 않다. 그렇기에 정밀도는 재현율(recall)이라는 또 다른 지표를 사용하는게 일반적이다. 재현율은 분류기가 정확하게 감지한 양성 샘플의 비율을 의미하며 민감도(sensitivity) 또는 True positive rate라고도 한다.
- 재현율 = TP /(TF + FN)
우리가 생성한 '5-검지기'의 정밀도와 재현율은 위와 같다. 정확도에서 봤던 만큼 좋은 점수가 나오진 않는다. 게다가 전체 숫자5 중에서 65%만 정확한 것을 알 수 있다. 이러한 정밀도와 재현율을 F1 점수라고 하는 하나의 숫자로 만들어 확인할 수도 있다.
즉 우리의 분류기 모델은 0.73 정도의 f1 score를 갖는 것을 알 수 있다. 이 f1 score가 항상 절대적인 것은 아니다. 상황에 따라 재현율이 중요할 수도 있고 정밀도가 중요할 수도 있다. 예를 들어 어린 아이에게 나쁜 동영상을 걸러내는 분류기를 훈련할 때, 재현율은 높으나 정말 나쁜 동영상이 몇개 노출 되는 것 보다 좋은 동영상이 많이 제외되더라도 안전한 것들만 노출시키는 (높은 정밀도) 분류기를 선호할 것이다. 다른 예로 감시 카메라를 통해 도둑을 잡아내는 분류기를 훈련시킨다고 할 때, 분류기의 재현율이 99%라면 정확도가 30%가 되더라도 괜찮을 지도 모른다. 경비원이 종종 잘못된 호출을 받겠지만 거의 모든 도둑을 잡을 것이다. 즉 이 두가지를 동시에 얻을 수는 없다. 정밀도와 재현율은 트레이드 오프 관계에 놓여있다.
'Projects' 카테고리의 다른 글
MNIST 분류 (3) - ROC curve (0) | 2020.11.18 |
---|---|
MNIST 분류 (2) - 정밀도와 재현율 (0) | 2020.11.17 |
캘리포니아 주택 가격 예측 (6) - 하이퍼 파라미터 튜닝 (0) | 2020.11.15 |
캘리포니아 주택 가격 예측 (4) - 모델링 (0) | 2020.11.13 |
캘리포니아 주택 가격 예측 (3) - Feature Engineering (0) | 2020.11.09 |