[Machine Learning] K-NN Algorithms_Part 2
<Introduction>
k-nearset neighbors 알고리즘에는 회귀 형태도 존재한다. k-neighbors regression의 경우 주변의 가장 가까운 K개의 샘플을 통해 값을 예측하는 방식이다. 주변에 가까운 값들의 데이터 값을 모두 더해서 n_neighbors의 값으로 나누어 평균을 구해주는 방식이다. 머신러닝에서 알고리즘을 배울 때 가장 좋은 방법은 예시를 보는 것이기 때문에 지금 바로 하나의 문제 풀어보도록 하자.
Question 1. 선형적인 구조를 띄는 데이터 셋이 주어졌을 때 n_neighbors가 얼마일때 가장 좋은 효율을 내는지와, 해당 값을 적용시켜서 87에서의 결과 값을 제시하고, 결과 값이 어떻게 도출되었는지를 증명하라.
한가지 문제에 풀어야 할 문제가 총 3가지로 나눠진다. n_neighbors가 얼마일때 가장 좋은 효율을 내는지와 87의 결과값, 해당 결과 값이 나오는 이유를 증명하는 것이다.
Part 1. 먼저 필요한 파이썬 라이브러리를 불러오자.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split # 데이터 분류 알고리즘
from sklearn.neighbors import KNeighborsRegressor # KNR 알고리즘
Part 2. 데이터를 불러오고, 데이터가 어떻게 생겼는지 알아본다.
data = pd.read_csv('/kaggle/input/random-linear-regression/test.csv').to_numpy()
print(data)
train_input, test_input, train_target, test_target = train_test_split(
data[:, 0], data[:, 1])
array([[ 77. , 79.77515201],
[ 21. , 23.17727887],
[ 22. , 25.60926156],
[ 20. , 17.85738813],
[ 36. , 41.84986439],
[ 15. , 9.80523488],
[ 62. , 58.87465933],
[ 95. , 97.61793701],
[ 20. , 18.39512747],
[ 5. , 8.74674765],
[ 4. , 2.81141583],
[ 19. , 17.09537241],
[ 96. , 95.14907176],
[ 62. , 61.38800663],
[ 36. , 40.24701716],
[ 15. , 14.82248589],
[ 65. , 66.95806869],
[ 14. , 16.63507984],
[ 87. , 90.65513736],
[ 69. , 77.22982636],
[ 89. , 92.11906278],
[ 51. , 46.91387709],
[ 89. , 89.82634442],
[ 27. , 21.71380347],
[ 97. , 97.41206981],
[ 58. , 57.01631363],
[ 79. , 78.31056542],
[ 21. , 19.1315097 ],
[ 93. , 93.03483388],
[ 27. , 26.59112396],
[ 99. , 97.55155344],
[ 31. , 31.43524822],
[ 33. , 35.12724777],
[ 80. , 78.61042432],
[ 28. , 33.07112825],
[ 47. , 51.69967172],
[ 53. , 53.62235225],
[ 69. , 69.46306072],
[ 28. , 27.42497237],
[ 33. , 36.34644189],
[ 91. , 95.06140858],
[ 71. , 68.16724757],
[ 50. , 50.96155532],
[ 76. , 78.04237454],
[ 4. , 5.60766487],
[ 37. , 36.11334779],
[ 70. , 67.2352155 ],
[ 68. , 65.01324035],
[ 40. , 38.14753871],
[ 35. , 34.31141446],
[ 94. , 95.28503937],
[ 88. , 87.84749912],
[ 52. , 54.08170635],
[ 31. , 31.93063515],
[ 59. , 59.61247085],
[ 0. , -1.04011421],
[ 39. , 47.49374765],
[ 64. , 62.60089773],
[ 69. , 70.9146434 ],
[ 57. , 56.14834113],
[ 13. , 14.05572877],
[ 72. , 68.11367147],
[ 76. , 75.59701346],
[ 61. , 59.225745 ],
[ 82. , 85.45504157],
[ 18. , 17.76197116],
[ 41. , 38.68888682],
[ 50. , 50.96343637],
[ 55. , 51.83503872],
[ 13. , 17.0761107 ],
[ 46. , 46.56141773],
[ 13. , 10.34754461],
[ 79. , 77.91032969],
[ 53. , 50.17008622],
[ 15. , 13.25690647],
[ 28. , 31.32274932],
[ 81. , 73.9308764 ],
[ 69. , 74.45114379],
[ 52. , 52.01932286],
[ 84. , 83.68820499],
[ 68. , 70.3698748 ],
[ 27. , 23.44479161],
[ 56. , 49.83051801],
[ 48. , 49.88226593],
[ 40. , 41.04525583],
[ 39. , 33.37834391],
[ 82. , 81.29750133],
[100. , 105.5918375 ],
[ 59. , 56.82457013],
[ 43. , 48.67252645],
[ 67. , 67.02150613],
[ 38. , 38.43076389],
[ 63. , 58.61466887],
[ 91. , 89.12377509],
[ 60. , 60.9105427 ],
[ 14. , 13.83959878],
[ 21. , 16.89085185],
[ 87. , 84.06676818],
[ 73. , 70.34969772],
[ 32. , 33.38474138],
[ 2. , -1.63296825],
[ 82. , 88.54475895],
[ 19. , 17.44047622],
[ 74. , 75.69298554],
[ 42. , 41.97607107],
[ 12. , 12.59244741],
[ 1. , 0.27530726],
[ 90. , 98.13258005],
[ 89. , 87.45721555],
[ 0. , -2.34473854],
[ 41. , 39.3294153 ],
[ 16. , 16.68715211],
[ 94. , 96.58888601],
[ 97. , 97.70342201],
[ 66. , 67.01715955],
[ 24. , 25.63476257],
[ 17. , 13.41310757],
[ 90. , 95.15647284],
[ 13. , 9.74416426],
[ 0. , -3.46788379],
[ 64. , 62.82816355],
[ 96. , 97.27405461],
[ 98. , 95.58017185],
[ 12. , 7.46850184],
[ 41. , 45.44599591],
[ 47. , 46.69013968],
[ 78. , 74.4993599 ],
[ 20. , 21.63500655],
[ 89. , 91.59548851],
[ 29. , 26.49487961],
[ 64. , 67.38654703],
[ 75. , 74.25362837],
[ 12. , 12.07991648],
[ 25. , 21.32273728],
[ 28. , 29.31770045],
[ 30. , 26.48713683],
[ 65. , 68.94699774],
[ 59. , 59.10598995],
[ 64. , 64.37521087],
[ 53. , 60.20758349],
[ 71. , 70.34329706],
[ 97. , 97.1082562 ],
[ 73. , 75.7584178 ],
[ 9. , 10.80462727],
[ 12. , 12.11219941],
[ 63. , 63.28312382],
[ 99. , 98.03017721],
[ 60. , 63.19354354],
[ 35. , 34.8534823 ],
[ 2. , -2.81991397],
[ 60. , 59.8313966 ],
[ 32. , 29.38505024],
[ 94. , 97.00148372],
[ 84. , 85.18657275],
[ 63. , 61.74063192],
[ 22. , 18.84798163],
[ 81. , 78.79008525],
[ 93. , 95.12400481],
[ 33. , 30.48881287],
[ 7. , 10.41468095],
[ 42. , 38.98317436],
[ 46. , 46.11021062],
[ 54. , 52.45103628],
[ 16. , 21.16523945],
[ 49. , 52.28620611],
[ 43. , 44.18863945],
[ 95. , 97.13832018],
[ 66. , 67.22008001],
[ 21. , 18.98322306],
[ 35. , 24.3884599 ],
[ 80. , 79.44769523],
[ 37. , 40.03504862],
[ 54. , 53.32005764],
[ 56. , 54.55446979],
[ 1. , -2.7611826 ],
[ 32. , 37.80182795],
[ 58. , 57.48741435],
[ 32. , 36.06292994],
[ 46. , 49.83538167],
[ 72. , 74.68953276],
[ 17. , 14.86159401],
[ 97. , 101.0697879 ],
[ 93. , 99.43577876],
[ 91. , 91.69240746],
[ 37. , 34.12473248],
[ 4. , 6.07939007],
[ 54. , 59.07247174],
[ 51. , 56.43046022],
[ 27. , 30.49412933],
[ 46. , 48.35172635],
[ 92. , 89.73153611],
[ 73. , 72.86282528],
[ 77. , 80.97144285],
[ 91. , 91.36566374],
[ 61. , 60.07137496],
[ 99. , 99.87382707],
[ 4. , 8.65571417],
[ 72. , 69.39858505],
[ 19. , 19.38780134],
[ 57. , 53.11628433],
[ 78. , 78.39683006],
[ 26. , 25.75612514],
[ 74. , 75.07484683],
[ 90. , 92.88772282],
[ 66. , 69.45498498],
[ 13. , 13.12109842],
[ 40. , 48.09843134],
[ 77. , 79.3142548 ],
[ 67. , 68.48820749],
[ 75. , 73.2300846 ],
[ 23. , 24.68362712],
[ 45. , 41.90368917],
[ 59. , 62.22635684],
[ 44. , 45.96396877],
[ 23. , 23.52647153],
[ 55. , 51.80035866],
[ 55. , 51.10774273],
[ 95. , 95.79747345],
[ 12. , 9.24113898],
[ 4. , 7.64652976],
[ 7. , 9.28169975],
[100. , 103.5266162 ],
[ 48. , 47.41006725],
[ 42. , 42.03835773],
[ 96. , 96.11982476],
[ 39. , 38.05766408],
[100. , 105.4503788 ],
[ 87. , 88.80306911],
[ 14. , 15.49301141],
[ 14. , 12.42624606],
[ 37. , 40.00709598],
[ 5. , 5.6340309 ],
[ 88. , 87.36938931],
[ 91. , 89.73951993],
[ 65. , 66.61499643],
[ 74. , 72.9138853 ],
[ 56. , 57.19103506],
[ 16. , 11.21710477],
[ 5. , 0.67607675],
[ 28. , 28.15668543],
[ 92. , 95.3958003 ],
[ 46. , 52.05490703],
[ 54. , 59.70864577],
[ 39. , 36.79224762],
[ 44. , 37.08457698],
[ 31. , 24.18437976],
[ 68. , 67.28725332],
[ 86. , 82.870594 ],
[ 90. , 89.899991 ],
[ 38. , 36.94173178],
[ 21. , 19.87562242],
[ 95. , 90.71481654],
[ 56. , 61.09367762],
[ 60. , 60.11134958],
[ 65. , 64.83296316],
[ 78. , 81.40381769],
[ 89. , 92.40217686],
[ 6. , 2.57662538],
[ 67. , 63.80768172],
[ 36. , 38.67780759],
[ 16. , 16.82839701],
[100. , 99.78687252],
[ 45. , 44.68913433],
[ 73. , 71.00377824],
[ 57. , 51.57326718],
[ 20. , 19.87846479],
[ 76. , 79.50341495],
[ 34. , 34.58876491],
[ 55. , 55.7383467 ],
[ 72. , 68.19721905],
[ 55. , 55.81628509],
[ 8. , 9.3914168 ],
[ 56. , 56.01448111],
[ 72. , 77.9969477 ],
[ 58. , 55.37049953],
[ 6. , 11.89457829],
[ 96. , 94.79081712],
[ 23. , 25.69041546],
[ 58. , 53.52042319],
[ 23. , 18.31396758],
[ 19. , 21.42637785],
[ 25. , 30.41303282],
[ 64. , 67.68142149],
[ 21. , 17.0854783 ],
[ 59. , 60.91792707],
[ 19. , 14.99514319],
[ 16. , 16.74923937],
[ 42. , 41.46923883],
[ 43. , 42.84526108],
[ 61. , 59.12912974],
[ 92. , 91.30863673],
[ 11. , 8.67333636],
[ 41. , 39.31485292],
[ 1. , 5.3136862 ],
[ 8. , 5.40522052],
[ 71. , 68.5458879 ],
[ 46. , 47.33487629],
[ 55. , 54.09063686],
[ 62. , 63.29717058],
[ 47. , 52.45946688]])
Part 3. 해당 데이터를 불러온 이후에 머신러닝 알고리즘 모델에 훈련시킬 수 있도록 데이터를 전처리 시킨다.
# 머신러닝 모델을 훈련시키기 위해서는 반드시 아래 과정을 거쳐야 한다.
train_input = train_input.reshape(-1, 1)
train_target = train_target.reshape(-1, 1)
test_input = test_input.reshape(-1, 1)
test_target = test_target.reshape(-1, 1)
Part 4. 훈련세트와 테스트세트가 어떻게 구성되어 있는지를 그래프화 시켜본다.
fig, ax = plt.subplots(ncols=2, figsize=(10, 4))
ax[0].scatter(train_input, train_target, label="train set")
ax[1].scatter(test_input, test_target, label="test set")
ax[0].legend()
ax[1].legend()
훈련세트와 테스트 세트를 한 그래프 안에 보이도록 해서 훈련세트와 테스트세트가 골고루 분배되어 있는지를 파악하자. 아래에 분류된 테스트세트는 데이터가 한 곳에서 밀집되어 있지 않기 때문에 좋게 분배되어 있다고 볼 수 있다.
fig, ax = plt.subplots()
ax.scatter(train_input, train_target, label="train set")
ax.scatter(test_input, test_target, label="test set")
ax.legend()
Part 5. 1부터 10까지 어떤 수가 적절할지 판단해야 한다.
train_acc = []
test_acc = []
neighbors = np.arange(1, 11)
for n in neighbors:
knr = KNeighborsRegressor(n_neighbors=n)
knr.fit(train_input, train_target)
train_acc.append(knr.score(train_input, train_target))
test_acc.append(knr.score(test_input, test_target))
print(train_acc)
print(test_acc)
[0.9867396996216738, 0.9923204650979393, 0.9926051629898496, 0.9918538603016698, 0.9911689694981655, 0.9908110385132987, 0.9901535049984052, 0.989689046831241, 0.9894069788417342, 0.9891426614162112]
[0.9770732287985622, 0.984061677207744, 0.984975663345472, 0.9864692391556494, 0.9868460867378576, 0.9870437381016879, 0.9860105642871373, 0.9865222622585592, 0.9861843204582866, 0.9858521274157183]
다음과 같이 매우 높은 값들이 나오는 것을 볼 수 있다.
Part 6. 해당 결과 값을 바탕으로 그래프를 그려보자.
fig, ax = plt.subplots()
ax.plot(neighbors, train_acc, label="train accuracy")
ax.plot(neighbors, test_acc, label="test accuracy")
ax.legend()
대부분의 값이 높지만, n_neighbors가 6정도 되는 경우에 훈련세트에 대한 정확도와 테스트세트에 대한 정확도 둘 다 좋은 결과를 보이는 것을 알 수 있다.
Part 7. n_neighbors=6으로 훈련시키자.
knr = KNeighborsRegressor(n_neighbors=6)
knr.fit(train_input, train_target)
knr.score(train_input, train_target)
0.9908110385132987
knr.score(test_input, test_target)
0.9870437381016879
역시 높은 결과 값을 보이는 것을 알 수 있다. 이정도면 매우 만족스러운 결과이다.
Part 8. 87에 대한 결과값을 예측해보자.
knr.predict([[87]])
array([[89.0105394]])
87이라는 입력 값에 89.0.. 이라는 결과가 나오는 것을 볼 수 있다.
Part 9. (87, 89)의 값을 그래프에 찍어보고, 이 값에 영향을 끼친 이웃들을 그래프에 표시해보자.
neighbors, index = knr.kneighbors([[87]]) # neighbors는 이웃까지의 거리를 나타내고 index는 번호를 나타낸다.
fig, ax = plt.subplots()
ax.scatter(train_input, train_target, marker='.', label='Train dataset')
ax.scatter(87, 89.0105394, marker='x')
ax.scatter(train_input[index], train_target[index], marker='*')
초록색 별이 바로 87이라는 입력값에 대한 결과값을 도출할 때 영향을 미쳤던 요소들이다.
print(neighbors)
array([[0., 0., 1., 2., 2., 2.]])
다음과 같은 결과가 나오는데 87에 가장 가까운 데이터와의 거리가 0부터 2까지 수이다.
Part 10. 해당 결과 값이 나오는 이유는 무엇일까?
분명히 서문에서 설명했듯이 KNeighborsRegressor는 알고 싶은 데이터에 가장 가까운 데이터들의 합을 총 개수로 나누어 구한다고 했다.
answer = sum(train_target[index]) / len(train_target[index])
print(sum(answer) / len(answer))
array([89.0105394])
따라서 87에 이웃한 6개의 데이터의 합을 6으로 나눠주게 되면 해당 알고리즘에서 예측했던 값과 똑같은 값이 나오게 된다.
우리가 공부했던 KNN Classifier 알고리즘과 KNR Regressor 알고리즘의 강점과 약점은 무엇일까?
먼저 KNN 알고리즘에서 가장 중요한 두가지는 가까운 이웃의 개수와 이웃 데이터까지의 거리를 어떻게 측정하는지 이다. 예를 들어서 작은 숫자를 이웃으로 두는 것보다 5개 정도의 이웃을 두었을 때 좋은 결과를 보였다. 하지만 이 값은 반드시 조정해야하는 요소이다. 그리고 기본적으로 유클리드 거리를 통해서 좋은 이웃 수를 구하곤 했는데, 거리를 잴 수 있는 더 많은 방법이 있기 때문에 다양한 방법을 시도해서 해당 알고리즘 모델에 효율적인 방법을 찾아야 한다.
KNN 알고리즘의 강점은 이해하기 쉽다는 것이다. 그리고 다양하게 조건들을 조정해주지 않아도 좋은 효율을 제공한다는 것이다. 해당 알고리즘을 사용하는 것은 앞으로 등장할 다양한 알고리즘들을 배우는데 기초가 될 것이다. 그리고 KNN 알고리즘을 사용할 때 가장 중요한 것은 데이터의 전처리이다. 대부분의 머신러닝 알고리즘을 구현할때 전처리는 중요한 요소이지만, KNN 알고리즘에서는 특히 더 그렇다. 하지만 이에 반해 약점은 데이터가 많으면 결과 값을 도출하는데 오래 걸리 뿐만 아니라 아래 그래프와 같이 주어진 데이터 밖의 데이터를 예상하려고 할 경우에는 예측이 불가능하기 때문에 명심해야한다.