浩哥强!
浩哥交的
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 29 15:55:19 2019
@author: lenovo
"""
import numpy as np
from math import sqrt
import operator as opt
def fun_for_judge(database,datatest,data_index,K):
diff = (database-datatest)**2#求坐标差
distance = (diff.sum(axis=1))**0.5#求距离
Index_have_sorted = distance.argsort()#根据数值,对坐标排序
Index_Former_K = Index_have_sorted[:K]#前K个的坐标
A_count=0#统计A,B类
B_count=0
for i in Index_Former_K:
if data_index[i]=='A':
A_count+=1
else :
B_count+=1
if A_count>B_count:
return True#返回A
else :
return False#返回B类
if __name__ == "__main__":
database = np.array([[1.2,3.1],[2.5,1.5]])#训练集
data_index = ['A','B']
datatest = np.array([2.1,2.6])#测试集
K=1
re = fun_for_judge(database,datatest,data_index,K)
print("你的测试点",datatest,end='')
if re :
print("属于A类")
else:
print("属于B类")