Olá, bom dia!
Estou com um modelo do tidymodels em produção no plumber que acontece um erro estranho…
Forma como o modelo é criado:
# Split data
set.seed(1221324)
df_train <- final_df %>%
dplyr::filter(createdat <= today - n_test)
df_test <- final_df %>%
dplyr::filter(createdat > today - n_test)
df_folds <- vfold_cv(df_train, strata = "make")
# Grid specification
xgb_grid <- dials::grid_latin_hypercube(
tree_depth(),
min_n(),
loss_reduction(),
sample_size = sample_prop(),
finalize(mtry(), df_train),
learn_rate(),
size = 40
)
# Workflow
xgb_wf <- workflows::workflow() %>%
workflows::add_formula(price ~ .) %>%
workflows::add_model(xgb_spec)
# Tuning
doParallel::registerDoParallel()
xgb_res <- tune::tune_grid(
xgb_wf,
resamples = df_folds,
grid = xgb_grid,
control = tune::control_grid(save_pred = TRUE)
)
best_rmse <- select_best(xgb_res, "rmse")
model <- finalize_workflow(
xgb_wf,
best_rmse
)
model_fitted= fit(model, df_train)
E então eu salvo com a função saveRDS
tanto os objetos model
e model_fitted
.
Se eu abrir uma nova sessão do R, eu consigo:
- Ler o model_fitted.rds e já rodar
stats::predict(model_fitted, test)
- Ler o model.rds, rodar a função
fit(model, df_train)
e depois rodarstats::predict()
Porém, quando eu tento ler esses mesmos dois arquivos em uma api minha no Plumber.
- Eu consigo ler o model.rds, rodar a função
fit(model, df_train)
e a partir daí rodar as predições comstats::predict()
- Eu não consigo utilizar o arquivo já pronto model_fitted.rds para fazer as predições.
E o erro que aparece é:
<simpleError in UseMethod("predict"): método não aplicável para 'predict' aplicado a um objeto de classe "workflow">
E gostaria de conseguir resolver isso, pois a primeira opção faz a minha máquina virtual demorar em média 10 a 15 minutos para ligar ela (já que precisa treinar o modelo), o que fica bem complicado em dias que talvez tenha uma instabilidade de uso nela, e tenha que reiniciar a máquina após um pico de consumo.
Alguém tem alguma ideia de como fazer funcionar o arquivo model_fitted.rds
no plumber? Para que eu não precise treinar o modelo (com a função fit()
)toda vez que ligo a api?
Obrigado.