R Regression: Linear Regression
Jump to navigation
Jump to search
# Ref: http://www.sthda.com/english/articles/40-regression-analysis/165-linear-regression-essentials-in-r/

# Setup ----
# Install each dependency only when it is missing. requireNamespace() checks
# availability without attaching the package, and the guards stop the script
# from re-downloading and reinstalling packages on every run.
if (!requireNamespace("devtools", quietly = TRUE)) {
  install.packages("devtools")
}
if (!requireNamespace("datarium", quietly = TRUE)) {
  devtools::install_github("kassambara/datarium")
}
if (!requireNamespace("tidyverse", quietly = TRUE)) {
  install.packages("tidyverse")
}
if (!requireNamespace("caret", quietly = TRUE)) {
  install.packages("caret")
}

# Example data sets ----
# Load three data sets and show the first three rows of each.
data("marketing", package = "datarium")
head(marketing, 3)

data("swiss")
head(swiss, 3)

data("Boston", package = "MASS")
head(Boston, 3)

# Libraries and plot theme ----
library(tidyverse)
library(caret)
theme_set(theme_bw())

# Load the data
data("marketing", package = "datarium")

# Inspect the data (three randomly chosen rows)
sample_n(marketing, 3)

# Split the data into training and test set
# The seed is fixed so the random partition is reproducible.
set.seed(123)
# createDataPartition() is given the outcome column with p = 0.8;
# list = FALSE makes the result usable directly as a row index below.
training.samples <- marketing$sales %>%
  createDataPartition(p = 0.8, list = FALSE)
train.data <- marketing[training.samples, ]
# Negative indexing keeps every row that was not selected for training.
test.data <- marketing[-training.samples, ]
# Build the model: regress sales on all remaining columns of train.data
model <- lm(sales ~ ., data = train.data)

# Summarize the model
summary(model)

# Make predictions for the rows of test.data
predictions <- model %>% predict(test.data)

# Model performance
# (a) Prediction error, RMSE
RMSE(predictions, test.data$sales)
# (b) R-square
R2(predictions, test.data$sales)

# Simple Linear Regression
# Refit with youtube as the only predictor; this overwrites `model`.
model <- lm(sales ~ youtube, data = train.data)
# Use the full element name: `$coef` only worked through partial matching
# of `$coefficients`.
summary(model)$coefficients

# Predict sales for youtube values of 0 and 1000
newdata <- data.frame(youtube = c(0, 1000))
model %>% predict(newdata)
# Multiple Linear Regression
# Fit sales on the three named predictors; this overwrites `model`.
model <- lm(sales ~ youtube + facebook + newspaper, data = train.data)
# Use the full element name: `$coef` only worked through partial matching
# of `$coefficients`.
summary(model)$coefficients

# New advertising budgets
newdata <- data.frame(
  youtube = 2000,
  facebook = 1000,
  newspaper = 1000
)

# Predict sales values
model %>% predict(newdata)

# Model summary
summary(model)

# Coefficients significance
summary(model)$coefficients

# CHANGE MODEL: remove newspaper
# Refit without the newspaper term; this overwrites `model` again.
model <- lm(sales ~ youtube + facebook, data = train.data)
summary(model)
# Score the current `model` on the rows of test.data
predictions <- predict(model, newdata = test.data)

# Model performance against the observed sales in test.data
# (a) Root mean squared prediction error
RMSE(pred = predictions, obs = test.data$sales)
# (b) R-squared
R2(pred = predictions, obs = test.data$sales)

# PLOT
# Scatter plot of sales against youtube over the full marketing data, with
# a smoothed trend layer; stat_smooth() is called with its default settings.
ggplot(data = marketing, mapping = aes(x = youtube, y = sales)) +
  geom_point() +
  stat_smooth()