You can download the data from here: https://www.kaggle.com/prathamtripathi/drug-classification
library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.5 v purrr 0.3.4
## v tibble 3.1.2 v dplyr 1.0.7
## v tidyr 1.1.3 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.1
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
df <- read_csv("drug200.csv")
##
## -- Column specification --------------------------------------------------------
## cols(
## Age = col_double(),
## Sex = col_character(),
## BP = col_character(),
## Cholesterol = col_character(),
## Na_to_K = col_double(),
## Drug = col_character()
## )
# looking for NAs
apply(is.na(df), 2, sum)
## Age Sex BP Cholesterol Na_to_K Drug
## 0 0 0 0 0 0
Visualization of the target variable:
t <- table(df$Drug)
t <- as.data.frame(t)
colnames(t) <- c("Drug","count")
ggplot(t, aes(x=Drug, y=count, fill=Drug)) +
geom_bar(stat="identity", color="black") +
theme_minimal() +
geom_text(aes(label=count), vjust=-0.6, size=6) +
scale_fill_brewer(palette="Set1")
Remember, KNN classifies a point by its Euclidean distance to its neighbors, so the calculation requires numeric inputs:
df$Drug<-as.factor(df$Drug)
df$Sex<-as.factor(df$Sex)
df$Cholesterol<-as.numeric(as_factor(df$Cholesterol))
df$BP<-as.numeric(as_factor(df$BP))
df$Sex<-as.numeric(df$Sex)
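As a quick illustration (toy values, not rows from the dataset), the Euclidean distance between two encoded feature vectors is just the square root of the summed squared differences:

```r
# Euclidean distance between two encoded feature vectors (toy values)
a <- c(0.5, 1, 0.0)
b <- c(0.3, 0, 0.5)
d <- sqrt(sum((a - b)^2))
d  # same result as dist(rbind(a, b))
```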
Euclidean distance is sensitive to the scale of each variable, so we normalize the values of every feature to the range 0:1.
normalize <- function(x) {
return((x - min(x)) / (max(x) - min(x)))
}
df[,-c(6)]<-apply(df[,-c(6)],2,normalize)
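A quick check with made-up numbers confirms the behavior: the minimum maps to 0, the maximum to 1, and everything else lands in between.

```r
# min-max normalization maps the smallest value to 0 and the largest to 1
normalize <- function(x) (x - min(x)) / (max(x) - min(x))
normalize(c(15, 25, 35, 45))  # 0, 1/3, 2/3, 1
```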
# splitting with caTools
library(caTools)
set.seed(123)
df<-as.data.frame(df)
# note: sample.split() is normally given the outcome vector (e.g. df$Drug)
# so the split is stratified; passing the whole data frame returns a short
# logical vector that subset() recycles row-wise, so this split is a
# deterministic every-nth-row pattern
sample<-sample.split(df,SplitRatio = 0.85)
train<-subset(df,sample==T)
test<-subset(df,sample==F)
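For reference, the same 85/15 idea can be sketched in base R without caTools (the row counts here are illustrative and this is not the split used above, since it is not stratified by Drug):

```r
# plain random 85/15 split on 200 rows (illustrative; not stratified)
set.seed(123)
n <- 200
train_idx <- sample(n, size = round(0.85 * n))
test_idx  <- setdiff(seq_len(n), train_idx)
c(train = length(train_idx), test = length(test_idx))  # 170 and 30
```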
library(kableExtra)
##
## Attaching package: 'kableExtra'
## The following object is masked from 'package:dplyr':
##
## group_rows
kable(test)
 | Age | Sex | BP | Cholesterol | Na_to_K | Drug |
---|---|---|---|---|---|---|
5 | 0.7796610 | 0 | 0.5 | 0 | 0.3681906 | DrugY |
11 | 0.5423729 | 0 | 0.5 | 0 | 0.1719307 | drugC |
17 | 0.9152542 | 1 | 0.5 | 1 | 0.1621740 | drugX |
23 | 0.5423729 | 1 | 0.5 | 1 | 0.7598662 | DrugY |
29 | 0.4067797 | 0 | 0.5 | 1 | 0.5137282 | DrugY |
35 | 0.6440678 | 1 | 1.0 | 0 | 0.2459191 | drugX |
41 | 0.9830508 | 0 | 1.0 | 0 | 0.4050285 | DrugY |
47 | 0.3728814 | 0 | 0.0 | 0 | 0.2133342 | drugA |
53 | 0.7966102 | 1 | 0.5 | 1 | 0.6540121 | DrugY |
59 | 0.7627119 | 1 | 1.0 | 1 | 0.1195197 | drugX |
65 | 0.7627119 | 0 | 0.0 | 0 | 0.2199637 | drugB |
71 | 0.9322034 | 1 | 0.0 | 0 | 0.2407280 | drugB |
77 | 0.3559322 | 0 | 0.0 | 0 | 0.1541372 | drugA |
83 | 0.2881356 | 0 | 0.5 | 0 | 0.1076678 | drugC |
89 | 0.3728814 | 0 | 0.0 | 1 | 0.5260492 | DrugY |
95 | 0.6949153 | 1 | 0.5 | 0 | 0.2735005 | DrugY |
101 | 0.2711864 | 1 | 0.0 | 1 | 0.1751829 | drugA |
107 | 0.1186441 | 1 | 1.0 | 0 | 0.1777472 | drugX |
113 | 0.3389831 | 1 | 0.5 | 1 | 0.0907186 | drugX |
119 | 0.2881356 | 0 | 0.0 | 1 | 0.1258052 | drugA |
125 | 0.6440678 | 0 | 0.0 | 1 | 0.1946964 | drugB |
131 | 0.9322034 | 0 | 1.0 | 0 | 0.4446807 | DrugY |
137 | 0.6779661 | 0 | 0.0 | 0 | 0.1472262 | drugB |
143 | 0.7627119 | 1 | 0.0 | 1 | 0.0735506 | drugB |
149 | 0.7796610 | 0 | 0.5 | 1 | 0.0334918 | drugX |
155 | 0.3728814 | 1 | 0.5 | 1 | 0.3269435 | DrugY |
161 | 0.2542373 | 0 | 1.0 | 0 | 0.1305272 | drugX |
167 | 0.7288136 | 0 | 0.5 | 0 | 0.6371881 | DrugY |
173 | 0.4067797 | 0 | 1.0 | 1 | 0.3426105 | DrugY |
179 | 0.4067797 | 1 | 1.0 | 0 | 0.3033335 | DrugY |
185 | 0.0508475 | 0 | 0.0 | 0 | 0.9668835 | DrugY |
191 | 0.7288136 | 1 | 0.0 | 0 | 0.3978360 | DrugY |
197 | 0.0169492 | 1 | 0.5 | 0 | 0.1794046 | drugC |
Sanity check: we’ve normalized our numeric features so that no single feature dominates the Euclidean distance, and we’ve coded our categorical features as numeric dummies so that they can be included in the distance calculations. Then we split our data for model evaluation.
In the next step we set up cross-validation. I chose 100 repeats because a recent article recommends this value (see Fränti and Sieranoja 2019). With number = 100 we also run 100-fold cross-validation. This process lets us determine the best value of k.
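Under the hood, each repeat of repeated k-fold CV simply partitions the training indices into k roughly equal groups; here is a minimal base-R sketch of how 100 folds over 167 training rows would be formed (caret does this internally, so this is illustration only):

```r
# partition 167 row indices into 100 roughly equal folds
set.seed(123)
n <- 167; k <- 100
folds <- split(sample(n), rep(seq_len(k), length.out = n))
length(folds)          # 100 folds
range(lengths(folds))  # each fold holds 1 or 2 observations
```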
library(rpart)
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
##
## lift
fit_control<-trainControl(method = "repeatedcv",number = 100,
repeats=100)
set.seed(123)
model<-caret::train(Drug~.,data=train,method="knn",trControl=fit_control,tuneGrid=expand.grid(k=1:20))
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, :
## There were missing values in resampled performance measures.
model
## k-Nearest Neighbors
##
## 167 samples
## 5 predictor
## 5 classes: 'drugA', 'drugB', 'drugC', 'drugX', 'DrugY'
##
## No pre-processing
## Resampling: Cross-Validated (100 fold, repeated 100 times)
## Summary of sample sizes: 163, 165, 165, 164, 165, 165, ...
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 1 0.8572144 0.7198673
## 2 0.7845243 0.5955541
## 3 0.7666090 0.5689619
## 4 0.7592450 0.5556471
## 5 0.7661586 0.5714250
## 6 0.7465099 0.5362901
## 7 0.7408108 0.5213184
## 8 0.7126288 0.4815372
## 9 0.7193153 0.4871902
## 10 0.7150829 0.4793582
## 11 0.7089027 0.4689272
## 12 0.7171117 0.4723061
## 13 0.7191009 0.4712694
## 14 0.7174342 0.4607085
## 15 0.7163153 0.4557729
## 16 0.7135279 0.4411374
## 17 0.7012955 0.4139307
## 18 0.6988523 0.4008048
## 19 0.6724486 0.3543964
## 20 0.6580216 0.3254944
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 1.
plot(model)
library(class)
set.seed(4)
modelknn<-knn(train=train[,1:5],test = test[,1:5],cl=train$Drug,k=1)
# note: confusionMatrix(data, reference) expects predictions first; with the
# arguments in this order the printed "Prediction"/"Reference" axes are
# swapped, though the overall accuracy is unaffected
confusionMatrix(test$Drug,modelknn)
## Confusion Matrix and Statistics
##
## Reference
## Prediction drugA drugB drugC drugX DrugY
## drugA 4 0 0 0 0
## drugB 1 3 0 0 1
## drugC 0 0 3 0 0
## drugX 0 0 0 6 1
## DrugY 0 0 2 1 11
##
## Overall Statistics
##
## Accuracy : 0.8182
## 95% CI : (0.6454, 0.9302)
## No Information Rate : 0.3939
## P-Value [Acc > NIR] : 7.571e-07
##
## Kappa : 0.755
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: drugA Class: drugB Class: drugC Class: drugX
## Sensitivity 0.8000 1.00000 0.60000 0.8571
## Specificity 1.0000 0.93333 1.00000 0.9615
## Pos Pred Value 1.0000 0.60000 1.00000 0.8571
## Neg Pred Value 0.9655 1.00000 0.93333 0.9615
## Prevalence 0.1515 0.09091 0.15152 0.2121
## Detection Rate 0.1212 0.09091 0.09091 0.1818
## Detection Prevalence 0.1212 0.15152 0.09091 0.2121
## Balanced Accuracy 0.9000 0.96667 0.80000 0.9093
## Class: DrugY
## Sensitivity 0.8462
## Specificity 0.8500
## Pos Pred Value 0.7857
## Neg Pred Value 0.8947
## Prevalence 0.3939
## Detection Rate 0.3333
## Detection Prevalence 0.4242
## Balanced Accuracy 0.8481
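The reported accuracy can be recomputed by hand from the confusion matrix above: the correct predictions sit on the diagonal, and accuracy is the diagonal sum over the total count.

```r
# rebuild the confusion matrix above and recompute overall accuracy
cm <- matrix(c(4, 0, 0, 0,  0,
               1, 3, 0, 0,  1,
               0, 0, 3, 0,  0,
               0, 0, 0, 6,  1,
               0, 0, 2, 1, 11),
             nrow = 5, byrow = TRUE)
sum(diag(cm)) / sum(cm)  # 27/33 = 0.8182, matching the Accuracy line
```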
Our predictive accuracy is 81.82 percent, as we see above. This is pretty good performance considering the simplicity of KNN.
In sum, the k-nearest neighbors classification approach is rather simple to understand and implement, yet it is very effective, as we can see. The training phase is very fast, and KNN makes no assumptions about the underlying data distribution, so it can be applied to a wide variety of problems. Of course, there are weaknesses as well. The first that comes to mind is that the choice of k is often arbitrary. Without scaling and encoding, KNN cannot handle nominal features and is sensitive to outliers. It also cannot work with missing data.
library(plyr)
## ------------------------------------------------------------------------------
## You have loaded plyr after dplyr - this is likely to cause problems.
## If you need functions from both plyr and dplyr, please load plyr first, then dplyr:
## library(plyr); library(dplyr)
## ------------------------------------------------------------------------------
##
## Attaching package: 'plyr'
## The following objects are masked from 'package:dplyr':
##
## arrange, count, desc, failwith, id, mutate, rename, summarise,
## summarize
## The following object is masked from 'package:purrr':
##
## compact
plot.df = data.frame(test[,1:5], predicted = modelknn)
plot.df1 = data.frame(x = plot.df$Na_to_K,
y = plot.df$Age,
predicted = plot.df$predicted)
find_hull = function(df) df[chull(df$x, df$y), ]
boundary = ddply(plot.df1, .variables = "predicted", .fun = find_hull)
ggplot(plot.df, aes(Na_to_K, Age, color = predicted, fill = predicted)) +
geom_point(size = 5) +
geom_polygon(data = boundary, aes(x,y), alpha = 0.5)