I'm having the following problem:
Context:
I'm using gtsummary to explore frequencies in a dataframe using cross variables.
Here's my desired output:
That is, I have a main variable, tobgp, crossed by multiple variables such as agegp and alcgp.
Attempt:
This is what I've done so far, using the esoph data from the datasets package (The R Datasets Package).
pacman::p_load(tidyverse, gt, gtsummary)
# Build a merged frequency table for `var`: an overall summary followed by the
# same summary split by `agegp` and by `alcgp` (both crosses are hard-coded).
#
# data: data frame containing `var`, `agegp` and `alcgp`.
# var:  unquoted column name of the main variable.
# Returns a gt table with one spanner per sub-table and a source note.
multiple_table <- function(data, var) {
  # Formatting shared by all three summaries: "pct% (n)", with 2 decimals for
  # the percentage and 0 for the count.
  stat_fmt <- all_categorical() ~ "{p}% ({n})"
  digit_fmt <- list(everything() ~ c(2, 0))

  # Overall distribution of the main variable.
  total_tbl <- data %>%
    select({{ var }}) %>%
    gtsummary::tbl_summary(statistic = stat_fmt, digits = digit_fmt) %>%
    modify_header(label ~ "") %>%
    bold_labels()

  # Main variable crossed by age group.
  by_age_tbl <- data %>%
    select({{ var }}, agegp) %>%
    gtsummary::tbl_summary(by = agegp, statistic = stat_fmt, digits = digit_fmt)

  # Main variable crossed by alcohol group.
  by_alc_tbl <- data %>%
    select({{ var }}, alcgp) %>%
    gtsummary::tbl_summary(by = alcgp, statistic = stat_fmt, digits = digit_fmt)

  # Merge side by side, convert to gt and add the source note.
  merged <- tbl_merge(
    tbls = list(total_tbl, by_age_tbl, by_alc_tbl),
    tab_spanner = c("**Total**", "**agegp**", "**algp**")
  )
  gt::tab_source_note(as_gt(merged), gt::md("*Fuente: Empresa1*"))
}
esoph %>%
multiple_table(tobgp)
The problem with my code so far is that it is specific to these crosses: to add more cross variables I have to modify the function I created, which is not very user-friendly.
Request:
Create a function that produces the desired output with one line of code. For example:
multiple_table(data, main, by)
# Desired call: main variable plus any number of cross variables.
esoph %>%
  multiple_table(main = tobgp, by = c(agegp, alcgp))
So that if I want to cross by other variables, I only have to change the by = c() argument.
This would make it easy to do something like:
# Same call with additional (placeholder) cross variables.
esoph %>%
  multiple_table(main = tobgp, by = c(agegp, alcgp, variable1, variable2))
Notes:
I've tried other functions in gtsummary, such as tbl_strata, which can use two variables as crosses, but it doesn't suit my needs because it nests the two cross variables like this:
This is not what I'm looking for. As you can see, the percentages for Drug are computed separately within each Grade. This example is taken from the gtsummary documentation: https://www.danieldsjoberg.com/gtsummary/reference/tbl_strata.html
I think the solution to my problem could involve some workaround with purrr or the apply family. I've tried a few, but I'm not very good at working with lists and iteration.
That's it. Thanks very much for reading, and I hope I've been clear. If not, feel free to ask.
Answers 28/03/22
Since I posted my question I've received answers with different approaches, all of which work perfectly. Feel free to use the one that suits you. Thanks to Mike for the answer on Stack Overflow, and thanks to Tan, June C and Tyler Grant Smith for the answers in the R4DS Slack community. In my case I will stick with approach 3.
Approach 1: The Mike approach
library(gtsummary)
library(dplyr)
esoph <- mutate(esoph,
ncases = ifelse(ncases > 2, "High","Low"))
# Build a merged gtsummary table: an overall summary of `var` followed by one
# cross-tabulation of `var` for every column named in `vars`.
#
# data: data frame containing all referenced columns.
# var:  name of the main variable, given as a string.
# vars: character vector with the names of the cross variables.
# Returns a gt table with a "Total" spanner, one spanner per cross variable
# and a source note.
multiple_table <- function(data, var, vars) {
  # all_of() marks `var` / `y` as external character vectors. A bare
  # `select(var)` would pick a column literally called "var" if one existed,
  # and tidyselect deprecates the bare form.
  t0 <- data %>%
    select(all_of(var)) %>%
    gtsummary::tbl_summary(
      statistic = all_categorical() ~ "{p}% ({n})",
      digits = list(everything() ~ c(2, 0))
    ) %>%
    modify_header(label ~ "") %>%
    bold_labels()

  # One cross table per entry of `vars`.
  tlist <- lapply(vars, function(y) {
    data %>%
      select(all_of(var), all_of(y)) %>%
      gtsummary::tbl_summary(
        by = all_of(y),
        statistic = all_categorical() ~ "{p}% ({n})",
        digits = list(everything() ~ c(2, 0))
      )
  })

  # Spanner labels follow the table order: the total first, then each cross.
  tabspannername <- c("**Total**", paste0("**", vars, "**"))
  tlist2 <- c(list(t0), tlist)

  tbl_merge(tbls = tlist2, tab_spanner = tabspannername) %>%
    as_gt() %>%
    gt::tab_source_note(gt::md("*Fuente: Empresa1*"))
}
multiple_table(data = esoph, var = "tobgp", vars = c("agegp", "alcgp","ncases"))
Approach 2: The Tan approach
library(tidyverse)
library(gt)
library(gtsummary)
esoph
# Summarise `main` split by the levels of `sub`.
#
# data: data frame holding both columns.
# main: column to summarise (forwarded to dplyr::select() with {{ }}).
# sub:  column passed as `by` to gtsummary::tbl_summary().
# Returns the tbl_summary object; categorical cells are shown as "pct% (n)".
fn_subtable <- function(data, main, sub) {
  picked <- dplyr::select(data, {{ main }}, {{ sub }})
  gtsummary::tbl_summary(
    picked,
    by = {{ sub }},
    statistic = gtsummary::all_categorical() ~ "{p}% ({n})",
    digits = list(dplyr::everything() ~ c(2, 0))
  )
}
# Merge an overall summary of `main_var` with one fn_subtable() result per
# element of `sub_vars`.
#
# data:     data frame containing all referenced columns.
# main_var: main variable (the example call passes its name as a string).
# sub_vars: cross variables (the example call passes a character vector).
# Returns a gt table with a "Total" spanner, one spanner per cross variable
# and a source note.
fn_table <- function(data, main_var, sub_vars) {
  # Overall distribution of the main variable, with a blank label header.
  overall_tbl <- gtsummary::tbl_summary(
    dplyr::select(data, {{ main_var }}),
    statistic = gtsummary::all_categorical() ~ "{p}% ({n})",
    digits = list(dplyr::everything() ~ c(2, 0))
  )
  overall_tbl <- gtsummary::modify_header(overall_tbl, label ~ "")
  overall_tbl <- gtsummary::bold_labels(overall_tbl)

  # One cross table per cross variable.
  cross_tbls <- purrr::map(
    sub_vars,
    function(sub_var) fn_subtable(data = data, main = main_var, sub = sub_var)
  )

  # MERGE
  spanners <- c("**Total**", paste0("**", sub_vars, "**"))
  merged <- gtsummary::tbl_merge(
    c(list(overall_tbl), cross_tbls),
    tab_spanner = spanners
  )
  gt::tab_source_note(gtsummary::as_gt(merged), gt::md("*Fuente: Empresa1*"))
}
esoph %>% fn_table("tobgp", c("agegp", "alcgp"))
Approach 3: The June C - Tyler Grant Smith approach
library(tidyverse)
library(gt)
library(gtsummary)
# Cross-tabulate `main` by `sub` with gtsummary.
#
# data: data frame that contains both columns.
# main: column to summarise; embraced with {{ }} so it is forwarded as given.
# sub:  grouping column, used both in the selection and as `by`.
# Returns the tbl_summary object with categorical cells shown as "pct% (n)".
fn_subtable <- function(data, main, sub) {
  two_cols <- dplyr::select(data, {{ main }}, {{ sub }})
  gtsummary::tbl_summary(
    two_cols,
    by = {{ sub }},
    statistic = gtsummary::all_categorical() ~ "{p}% ({n})",
    digits = list(dplyr::everything() ~ c(2, 0))
  )
}
# Merge an overall summary of `main_var` with one cross table per variable in
# `sub_vars`, accepting unquoted column names for both arguments.
#
# data:     data frame containing all referenced columns.
# main_var: unquoted name of the main variable.
# sub_vars: a call wrapping unquoted column names, e.g. list(agegp, alcgp).
# Returns the gt table built from the merged tables, with a source note.
fn_table3 <- function(data, main_var, sub_vars){
# Capture the main variable unevaluated so it can be forwarded later.
main_var <- rlang::enexpr(main_var)
sub_vars_expr <- rlang::enexpr(sub_vars) # 1. Capture `list(...)` call as expression
sub_vars_args <- rlang::call_args(sub_vars_expr) # 2. Pull out the arguments (they're now also exprs)
# NOTE(review): rlang::call_fn() appears to be deprecated in recent rlang
# releases - confirm against the installed version.
sub_vars_fn <- rlang::call_fn(sub_vars_expr) # 3. Pull out the fn call
# 4. Evaluate the fn with expr-ed arguments (this becomes `list( expr(agegp), expr(alcgp) )` )
sub_vars_reconstructed <- rlang::exec(sub_vars_fn, !!!sub_vars_args)
# --- sub_vars replaced with sub_vars_reconstructed from here onwards ---
# Overall summary of the main variable: "pct% (n)" cells, blank label header,
# bold variable labels.
t0 <- data %>%
dplyr::select({{main_var}}) %>%
gtsummary::tbl_summary(statistic = gtsummary::all_categorical() ~ "{p}% ({n})",
digits = list(dplyr::everything() ~ c(2, 0))) %>%
gtsummary::modify_header(label ~ "") %>%
gtsummary::bold_labels()
# One cross table per captured cross variable, built by fn_subtable().
sub_tables <- purrr::map(sub_vars_reconstructed, ~fn_subtable(data = data, main = main_var, sub = .x))
# Merge everything; the spanner labels are pasted from the captured cross
# variables, preceded by "Total".
tbls <- c(list(t0), sub_tables) %>%
gtsummary::tbl_merge(tab_spanner = c("**Total**", paste0("**",sub_vars_reconstructed,"**"))) %>%
gtsummary::as_gt() %>%
gt::tab_source_note(gt::md("*Fuente: Empresa1*"))
tbls
}
fn_table3(esoph,tobgp,list(agegp,alcgp))
Thanks very much. I hope this can be implemented as a function inside the gtsummary package, because it is very useful for exploring frequencies with different cross variables.
You are pretty close and only need a few modifications. The major change is adding an lapply() that loops through the vars input to create a list of tbl_summary objects. Then I create the tab spanner names from vars and prepend the t0 table to the list created by lapply(). Finally, you can pass tlist2 to tbl_merge() together with the names in tabspannername to label the tables dynamically.
library(gtsummary)
library(dplyr)
esoph <- mutate(esoph,
ncases = ifelse(ncases > 2, "High","Low"))
# Build a merged gtsummary table: an overall summary of `var` followed by one
# cross-tabulation of `var` for every column named in `vars`.
#
# data: data frame containing all referenced columns.
# var:  name of the main variable, given as a string.
# vars: character vector with the names of the cross variables.
# Returns a gt table with a "Total" spanner, one spanner per cross variable
# and a source note.
multiple_table <- function(data, var, vars) {
  # all_of() marks `var` / `y` as external character vectors, so a column
  # that happens to be called "var" or "y" can never be selected by mistake.
  t0 <- data %>%
    select(all_of(var)) %>%
    gtsummary::tbl_summary(
      statistic = all_categorical() ~ "{p}% ({n})",
      digits = list(everything() ~ c(2, 0))
    ) %>%
    modify_header(label ~ "") %>%
    bold_labels()

  # One cross table per entry of `vars`. These are built from the `data`
  # argument, not the global `esoph`, so the function works on any data frame.
  tlist <- lapply(vars, function(y) {
    data %>%
      select(all_of(var), all_of(y)) %>%
      gtsummary::tbl_summary(
        by = all_of(y),
        statistic = all_categorical() ~ "{p}% ({n})",
        digits = list(everything() ~ c(2, 0))
      )
  })

  # Spanner labels follow the table order: the total first, then each cross.
  tabspannername <- c("**Total**", paste0("**", vars, "**"))
  tlist2 <- c(list(t0), tlist)

  tbl_merge(tbls = tlist2, tab_spanner = tabspannername) %>%
    as_gt() %>%
    gt::tab_source_note(gt::md("*Fuente: Empresa1*"))
}
x <- multiple_table(data = esoph, var = "tobgp", vars = c("agegp", "alcgp","ncases"))
Related
I'm trying to use the ARIMA function from the fable package. I'd like to test, using cross validation, every specification, given by the pdqPDQ data.frame rows, using a multisession plan from the future package. I will then make forecasts and later calculate accuracy measures.
ARIMA function cannot see the pdqPDQ object. I'm aware of the future missing globals issues, and maybe that's the case here (?).
Any ideas for how I could solve this?
library(GetBCBData)
library(lubridate)
library(tsibble)
library(fable)
library(tidyr)
library(future)
library(dplyr)
#============================================================#
#Data ----
#============================================================#
ipca <- gbcbd_get_series(c(433, 4449, 10844, 11428, 27863, 27864), first.date = "01/01/2004")
ipca <-
ipca %>%
mutate(series.name =
case_when(id.num == 433 ~ "ipca",
id.num == 4449 ~ "administrados",
id.num == 10844 ~ "serviços",
id.num == 11428 ~ "livres",
id.num == 27863 ~ "industriais",
id.num == 27864 ~ "alimentos",
TRUE ~ series.name))
ipca <-
ipca %>%
select(data = ref.date, valor = value, series.name) %>%
pivot_wider(names_from = "series.name", values_from = "valor")
ipca_tsb <-
ipca %>%
mutate(data = yearmonth(data)) %>%
arrange(data) %>%
as_tsibble()
#============================================================#
#fable and future: Time series cross validation forecast ----
#============================================================#
ipca_fable <-
ipca_tsb %>%
stretch_tsibble(.step = 1, .init = 144)
model_list <- list()
pdqPDQ <- expand.grid(p = 0:4, d = 0, q = 0:4, P = 0:2, D = 0:1, Q = 0:2)
plan(multisession)
for (i in 1:nrow(pdqPDQ)) {
print(pdqPDQ[i,])
#constante incluída
model_list[[i]] <-
ipca_fable %>%
model(ARIMA(alimentos ~ 1 + pdq(pdqPDQ[i, 1], pdqPDQ[i, 2], pdqPDQ[i, 3]) +
PDQ(pdqPDQ[i, 4], pdqPDQ[i, 5], pdqPDQ[i, 6]))) %>%
forecast(h = 18) %>%
group_by(.id) %>%
mutate(h = row_number()) %>%
ungroup() %>%
#accuracy requer classe fable
as_fable(response = "alimentos", distribution = alimentos)
}
I don't like answering my own question, but this was too long for a comment. It may not be the most elegant solution, but I was able to solve the error (Error: object 'pdqPDQ' not found) using the listenv package, which allowed me to supply the pdqPDQ object inside the braces.
library(listenv)

# Fit every (p, d, q)(P, D, Q) specification from pdqPDQ. Each iteration is
# assigned with %<-% into a listenv, one element per row of pdqPDQ.
model_list <- listenv()
for (i in seq_len(nrow(pdqPDQ))) {
  # constant included
  model_list[[i]] %<-% {
    # Bare reference to pdqPDQ inside the braces: this is what resolved the
    # "object 'pdqPDQ' not found" error.
    pdqPDQ
    ipca_fable %>%
      model(ARIMA(alimentos ~ 1 + pdq(pdqPDQ[i, 1], pdqPDQ[i, 2], pdqPDQ[i, 3]) +
        PDQ(pdqPDQ[i, 4], pdqPDQ[i, 5], pdqPDQ[i, 6]), method = "ML")) %>%
      # Name the model column after the specification being fitted.
      rename_with(~c(".id", paste0(pdqPDQ[i,], collapse = ", "))) %>%
      forecast(h = 18) %>%
      group_by(.id) %>%
      mutate(h = row_number()) %>%
      ungroup() %>%
      # apparently nothing changes structurally, but accuracy requires the fable class
      as_fable(response = "alimentos", distribution = alimentos)
  }
}
Given the following code
library(tidyverse)
library(lubridate)
library(tidymodels)
library(ranger)
df <- read_csv("https://raw.githubusercontent.com/norhther/datasets/main/bitcoin.csv")
df <- df %>%
mutate(Date = dmy(Date),
Change_Percent = str_replace(Change_Percent, "%", ""),
Change_Percent = as.double(Change_Percent)
) %>%
filter(year(Date) > 2017)
# Date interval used to build the `covid` flag.
int <- interval(ymd("2020-01-20"), ymd("2022-01-15"))
# `%within%` already yields a logical vector, so no ifelse(..., T, F) wrapper
# is needed (and T / F are ordinary variables that can be reassigned).
df <- df %>%
  mutate(covid = Date %within% int)
df %>%
ggplot(aes(x = Date, y = Price, color = covid)) +
geom_line()
df <- df %>%
arrange(Date) %>%
mutate(lag1 = lag(Price),
lag2 = lag(lag1),
lag3 = lag(lag2),
profit_next_day = lead(Profit))
# modelatge
df_mod <- df %>%
select(-covid, -Date, -Vol_K, -Profit) %>%
mutate(profit_next_day = as.factor(profit_next_day))
set.seed(42)
data_split <- initial_split(df_mod) # 3/4
train_data <- training(data_split)
test_data <- testing(data_split)
bitcoin_rec <-
recipe(profit_next_day ~ ., data = train_data) %>%
step_naomit(all_outcomes(), all_predictors()) %>%
step_normalize(all_numeric_predictors())
bitcoin_prep <-
prep(bitcoin_rec)
bitcoin_train <- juice(bitcoin_prep)
bitcoin_test <- bake(bitcoin_prep, test_data)
rf_spec <-
rand_forest(trees = 200) %>%
set_engine("ranger", importance = "impurity") %>%
set_mode("classification")
bitcoin_wflow <-
workflow() %>%
add_model(rf_spec) %>%
add_recipe(bitcoin_prep)
bitcoin_fit <-
bitcoin_wflow %>%
fit(data = train_data)
final_model <- last_fit(bitcoin_wflow, data_split)
collect_metrics(final_model)
final_model %>%
extract_workflow() %>%
predict(test_data)
The last chunk of code that extracts the workflow and predicts the test_data is throwing the error:
Error in stop_subscript(): ! Can't subset columns that don't exist.
x Column profit_next_day doesn't exist.
but profit_next_day exists already in test_data, as I checked multiple times, so I don't know what is happening. Never had this error before working with tidymodels.
The problem here comes from using step_naomit() on the outcome. In general, steps that change rows (such as removing them) can be pretty tricky when it comes time to resample or predict on new data. You can read more in detail in our book, but I would suggest that you remove step_naomit() altogether from your recipe and change your earlier code to:
df_mod <- df %>%
select(-covid, -Date, -Vol_K, -Profit) %>%
mutate(profit_next_day = as.factor(profit_next_day)) %>%
na.omit()
As I want to produce some visualizations and analysis on forecasted data outside the modeltime framework, I need to extract confidence values, fitted values and maybe also residuals.
The documentation indicates, that I need to use the function modeltime_calibrate() to get the confidence values and residuals. So one question would be, where do I extract the fitted values from?
My main question, however, is how to do calibration on recursive ensembles. For any non-ensemble model I was able to do it, but in the case of recursive ensembles I encounter error messages when I try to calibrate.
To illustrate the problem, look at the example code below, which ends up failing to calibrate all models:
library(modeltime.ensemble)
library(modeltime)
library(tidymodels)
library(earth)
library(glmnet)
library(xgboost)
library(tidyverse)
library(lubridate)
library(timetk)
FORECAST_HORIZON <- 24
m4_extended <- m4_monthly %>%
group_by(id) %>%
future_frame(
.length_out = FORECAST_HORIZON,
.bind_data = TRUE
) %>%
ungroup()
# Add lagged copies of `value` (lags 1 to FORECAST_HORIZON) computed within
# each `id` group, and return the ungrouped result.
lag_transformer_grouped <- function(data) {
  by_id <- group_by(data, id)
  with_lags <- tk_augment_lags(by_id, value, .lags = 1:FORECAST_HORIZON)
  ungroup(with_lags)
}
m4_lags <- m4_extended %>%
lag_transformer_grouped()
test_data <- m4_lags %>%
group_by(id) %>%
slice_tail(n = 12) %>%
ungroup()
train_data <- m4_lags %>%
drop_na()
future_data <- m4_lags %>%
filter(is.na(value))
model_fit_glmnet <- linear_reg(penalty = 1) %>%
set_engine("glmnet") %>%
fit(value ~ ., data = train_data)
model_fit_xgboost <- boost_tree("regression", learn_rate = 0.35) %>%
set_engine("xgboost") %>%
fit(value ~ ., data = train_data)
recursive_ensemble_panel <- modeltime_table(
model_fit_glmnet,
model_fit_xgboost
) %>%
ensemble_weighted(loadings = c(4, 6)) %>%
recursive(
transform = lag_transformer_grouped,
train_tail = panel_tail(train_data, id, FORECAST_HORIZON),
id = "id"
)
model_tbl <- modeltime_table(
recursive_ensemble_panel
)
calibrated_mod <- model_tbl %>%
modeltime_calibrate(test_data, id = "id", quiet = FALSE)
model_tbl %>%
modeltime_forecast(
new_data = future_data,
actual_data = m4_lags,
keep_data = TRUE
) %>%
group_by(id) %>%
plot_modeltime_forecast(
.interactive = FALSE,
.conf_interval_show = TRUE,
.facet_ncol = 2
)
The problem lies in your recursive_ensemble_panel. You have to do the recursive part on the models themselves and not the ensemble. Like you I would have expected to do the recursive in one go, maybe via modeltime_table.
# start of changes to your code.
# added recursive to the model
model_fit_glmnet <- linear_reg(penalty = 1) %>%
set_engine("glmnet") %>%
fit(value ~ ., data = train_data) %>%
recursive(
transform = lag_transformer_grouped,
train_tail = panel_tail(train_data, id, FORECAST_HORIZON),
id = "id"
)
# added recursive to the model
model_fit_xgboost <- boost_tree("regression", learn_rate = 0.35) %>%
set_engine("xgboost") %>%
fit(value ~ ., data = train_data) %>%
recursive(
transform = lag_transformer_grouped,
train_tail = panel_tail(train_data, id, FORECAST_HORIZON),
id = "id"
)
# removed recursive part
recursive_ensemble_panel <- modeltime_table(
model_fit_glmnet,
model_fit_xgboost
) %>%
ensemble_weighted(loadings = c(4, 6))
# rest of your code
I had to do some experimentation to find the right way to extract what I need (confidence intervals and residuals).
As you can see from the example code below, there needs to be a change in the models workflow to achieve this. Recursion needs to appear in the workflow object definition and neither in the model nor in the ensemble fit/specification.
I still have to do some tests here, but I guess, that I got what I need now:
# Time Series ML
library(tidymodels)
library(modeltime)
library(modeltime.ensemble)
# Core
library(tidyverse)
library(timetk)
# data def
FORECAST_HORIZON <- 24
# Add lagged copies of `value` (lags 1 to FORECAST_HORIZON) within each `id`
# group of the input data and return the ungrouped result.
lag_transformer_grouped <- function(m750) {
  grouped <- group_by(m750, id)
  lagged <- tk_augment_lags(grouped, value, .lags = 1:FORECAST_HORIZON)
  ungroup(lagged)
}
m750_lags <- m750 %>%
lag_transformer_grouped()
test_data <- m750_lags %>%
group_by(id) %>%
slice_tail(n = 12) %>%
ungroup()
train_data <- m750_lags %>%
drop_na()
future_data <- m750_lags %>%
filter(is.na(value))
# rec
recipe_spec <- recipe(value ~ date, train_data) %>%
step_timeseries_signature(date) %>%
step_rm(matches("(.iso$)|(.xts$)")) %>%
step_normalize(matches("(index.num$)|(_year$)")) %>%
step_dummy(all_nominal()) %>%
step_fourier(date, K = 1, period = 12)
recipe_spec %>% prep() %>% juice()
# elnet
model_fit_glmnet <- linear_reg(penalty = 1) %>%
set_engine("glmnet")
wflw_fit_glmnet <- workflow() %>%
add_model(model_fit_glmnet) %>%
add_recipe(recipe_spec %>% step_rm(date)) %>%
fit(train_data) %>%
recursive(
transform = lag_transformer_grouped,
train_tail = panel_tail(train_data, id, FORECAST_HORIZON),
id = "id"
)
# xgboost
model_fit_xgboost <- boost_tree("regression", learn_rate = 0.35) %>%
set_engine("xgboost")
wflw_fit_xgboost <- workflow() %>%
add_model(model_fit_xgboost) %>%
add_recipe(recipe_spec %>% step_rm(date)) %>%
fit(train_data) %>%
recursive(
transform = lag_transformer_grouped,
train_tail = panel_tail(train_data, id, FORECAST_HORIZON),
id = "id"
)
# mtbl
m750_models <- modeltime_table(
wflw_fit_xgboost,
wflw_fit_glmnet
)
# mfit
ensemble_fit <- m750_models %>%
ensemble_average(type = "mean")
# mcalib
calibration_tbl <- modeltime_table(
ensemble_fit
) %>%
modeltime_calibrate(test_data)
# residuals
calib_out <- calibration_tbl$.calibration_data[[1]] %>%
left_join(test_data %>% select(id, date, value))
# Forecast ex post
ex_post_obj <-
calibration_tbl %>%
modeltime_forecast(
new_data = test_data,
actual_data = m750
)
# Forecast ex ante
data_prepared_tbl <- bind_rows(train_data, test_data)
future_tbl <- data_prepared_tbl %>%
group_by(id) %>%
future_frame(.length_out = "2 years") %>%
ungroup()
ex_ante_obj <-
calibration_tbl %>%
modeltime_forecast(
new_data = future_tbl,
actual_data = m750
)
I'm trying to view how this model performs against prior actual close. I'm using a workflow_set model and have no issues extracting the forecast. I've supplied a reproducible example below. I'd like to be able to plot actual, with a backtested trend line along with the forecast.
tickers <- "TSLA"
first.date <- Sys.Date() - 3000
last.date <- Sys.Date()
freq.data <- "daily"
stocks <- BatchGetSymbols::BatchGetSymbols(tickers = tickers,
first.date = first.date,
last.date = last.date,
freq.data = freq.data ,
do.cache = FALSE,
thresh.bad.data = 0)
stocks <- stocks %>% as.data.frame() %>% select(Date = df.tickers.ref.date, Close = df.tickers.price.close)
time_val_split <-
stocks %>%
sliding_period(
Date,
period = "day",
every = 52)
data_extended <- stocks %>%
future_frame(
.length_out = 60,
.bind_data = TRUE
) %>%
ungroup()
train_tbl <- data_extended %>% drop_na()
future_tbl <- data_extended %>% filter(is.na(Close))
base_rec <- recipe(Close ~ Date, train_tbl) %>%
step_timeseries_signature(Date) %>%
step_rm(matches("(.xts$)|(.iso$)|(.lbl)|(hour)|(minute)|(second)|(am.pm)|(mweek)|(qday)|(week2)|(week3)|(week4)")) %>%
step_dummy(all_nominal(), one_hot = TRUE) %>%
step_normalize(all_numeric_predictors()) %>%
step_scale(all_numeric_predictors()) %>%
step_rm(Date)
cubist_spec <-
cubist_rules(committees = tune(),
neighbors = tune()) %>%
set_engine("Cubist")
rf_spec <-
rand_forest(mtry = tune(),
min_n = tune(),
trees = 1000) %>%
set_engine("ranger") %>%
set_mode("regression")
# Combine the date-based recipe with each model specification.
# The second model is `rf_spec`, the random forest defined earlier in this
# script; `cart_spec` was referenced before but is never defined.
base <-
  workflow_set(
    preproc = list(base_date = base_rec),
    models = list(
      cubist_base = cubist_spec,
      rf_base = rf_spec
    )
  )
all_workflows <-
bind_rows(
base
)
# Tune all workflows in parallel on a PSOCK cluster with one worker per
# physical core.
cores <- parallel::detectCores(logical = FALSE)
clusters <- parallel::makePSOCKcluster(cores)
doParallel::registerDoParallel(clusters)
wflwset_tune_results <-
  all_workflows %>%
  workflow_map(
    fn = "tune_race_anova",
    seed = 1,
    resamples = time_val_split,
    grid = 2,
    verbose = TRUE
  )
# The cluster was created explicitly, so it must be shut down explicitly.
# stopImplicitCluster() only stops a cluster that registerDoParallel()
# created on its own.
parallel::stopCluster(clusters)
best_for_each_mod <- wflwset_tune_results %>%
rank_results(select_best = TRUE) %>%
filter(.metric == "rmse") %>%
select(wflow_id, .config, mean, preprocessor, model)
b_mod <- best_for_each_mod %>%
arrange(mean) %>%
head(1) %>%
select(wflow_id) %>% as.character()
best_param <- wflwset_tune_results %>% extract_workflow_set_result(id = b_mod) %>% select_best(metric = "rmse")
# Finalize model with best param
best_finalized <- wflwset_tune_results %>%
extract_workflow(b_mod) %>%
finalize_workflow(best_param) %>%
fit(train_tbl)
At this point the model has been trained, but I can't figure out how to run it against prior actuals. My goal is to bind the backtested results with the predictions below.
prediction_tbl <- best_finalized %>%
predict(new_data = future_tbl) %>%
bind_cols(future_tbl) %>%
select(.pred, Date) %>%
mutate(type = "prediction") %>%
rename(Close = .pred)
train_tbl %>% mutate(type = "actual") %>% rbind(prediction_tbl) %>%
ggplot(aes(Date, Close, color = type)) +
geom_line(size = 2)
Based on your comment, I'd recommend using pivot_longer() after binding the future_tbl to your predictions. This lets you keep everything in one pipeline, rather than having to create two separate dataframes then bind them together. Here's an example plotting the prediction & actual values against mpg. Hope this helps!
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
# split data
set.seed(123)
mtcars <- as_tibble(mtcars)
cars_split <- initial_split(mtcars)
cars_train <- training(cars_split)
cars_test <- testing(cars_split)
# plot truth & prediction against another variable
workflow() %>%
add_model(linear_reg() %>% set_engine("lm")) %>%
add_recipe(recipe(qsec ~ ., data = cars_train)) %>%
fit(cars_train) %>%
predict(cars_test) %>%
bind_cols(cars_test) %>%
pivot_longer(cols = c(.pred, qsec),
names_to = "comparison",
values_to = "value") %>%
ggplot(aes(x = mpg,
y = value,
color = comparison)) +
geom_point(alpha = 0.75)
Created on 2021-11-18 by the reprex package (v2.0.1)
In the following code, I want to replace map_dfr from purrr with one of the SparkR apply functions to parallelize the Shapley calculations on the azure databricks:
#install.packages("randomForest"); install.packages("tidyverse"); install.packages("iml"); install.packages(SparkR)
library(tidyverse); library(iml); library(randomForest); library(SparkR)
# Example data: `vs` as a factor outcome plus a row id.
mtcars1 <- mtcars %>% mutate(vs = as.factor(vs), id = row_number())
# Formula pieces: `x` is the outcome, `y` the "+"-joined predictors
# (every column except the outcome and the id).
x <- "vs"
y <- paste0(setdiff(names(mtcars1), c("vs", "id")), collapse = "+")
rf <- randomForest(as.formula(paste0(x, "~ ", y)), data = mtcars1, ntree = 50)
predictor <- Predictor$new(rf, data = mtcars1, y = mtcars1$vs)
# For every row of mtcars1 (the same object the rows are taken from), keep
# the five features with the largest phi and tag them with the row index.
shapelyresults <- map_dfr(seq_len(nrow(mtcars1)), ~ (Shapley$new(predictor, x.interest = mtcars1[.x, ]) %>%
  .$results %>%
  as_tibble() %>%
  arrange(desc(phi)) %>%
  slice(1:5) %>%
  select(feature.value, phi) %>%
  mutate(id = .x)))
I could not leverage the answer on the following link: How to apply a function to each row in SparkR?