[Plumber][Tidymodels] Erro ao usar a função predict em um fit de classe workflow dentro do plumber

Olá, bom dia!

Estou com um modelo do tidymodels em produção no plumber que acontece um erro estranho…

Forma como o modelo é criado:

# Split data
set.seed(1221324)
df_train <- final_df %>%
  dplyr::filter(createdat <= today - n_test)

df_test <- final_df %>%
  dplyr::filter(createdat > today - n_test)

df_folds <- vfold_cv(df_train, strata = "make")

# Grid specification
xgb_grid <- dials::grid_latin_hypercube(
  tree_depth(),
  min_n(),
  loss_reduction(),
  sample_size = sample_prop(),
  finalize(mtry(), df_train),
  learn_rate(),
  size = 40
)

# Workflow
xgb_wf <- workflows::workflow() %>%
  workflows::add_formula(price ~ .) %>%
  workflows::add_model(xgb_spec)

# Tuning
doParallel::registerDoParallel()
xgb_res <- tune::tune_grid(
  xgb_wf,
  resamples = df_folds,
  grid = xgb_grid,
  control = tune::control_grid(save_pred = TRUE)
)

best_rmse <- select_best(xgb_res, "rmse")

model <- finalize_workflow(
  xgb_wf,
  best_rmse
)

model_fitted= fit(model, df_train)

E então eu salvo com a função saveRDS tanto os objetos model e model_fitted.

Se eu abrir uma nova sessão do R, eu consigo:

  • Ler o model_fitted.rds e já rodar stats::predict(model_fitted, test)
  • Ler o model.rds, rodar a função fit(model, df_train) e depois rodar stats::predict()

Porém, quando eu tento ler esses mesmos dois arquivos em uma api minha no Plumber.

  • Eu consigo ler o model.rds, rodar a função fit(model, df_train) e a partir daí rodar as predições com stats::predict()
  • Eu não consigo utilizar o arquivo já pronto model_fitted.rds para fazer as predições.

E o erro que aparece é:

<simpleError in UseMethod("predict"): método não aplicável para 'predict' aplicado a um objeto de classe "workflow">

E gostaria de conseguir resolver isso, pois a primeira opção faz a minha máquina virtual demorar em média 10 a 15 minutos para ligar ela (já que precisa treinar o modelo), o que fica bem complicado em dias que talvez tenha uma instabilidade de uso nela, e tenha que reiniciar a máquina após um pico de consumo.

Alguém tem alguma ideia de como fazer funcionar o arquivo model_fitted.rds no plumber? Para que eu não precise treinar o modelo (com a função fit())toda vez que ligo a api?

Obrigado.

André! to testando aqui, mas enquanto isso só pra verificar, vc carregou o library(workflows) no plumber?

só pra ilustrar a minha suspeita, o predict() é uma função genérica. Parece q ele não tá encontrando a função workflows:::predict.workflow() que é a função q ele usaria pra fazer o predict de um objeto workflow.

Veja se isso resolve!

1 curtida

@Athos você é perfeito cara!

Era exatamente isso.

Como eu tinha o pacote workflows carregado na minha sessão R, e eu conseguia rodar o predict com stats::predict eu achava que era esse predict mesmo. E então usava apenas essa diretamente no plumber.

Mas tentei agora aqui com workflows:::predict.workflow e rodou perfeitamente.

Muito obrigado Athos!

Edit: Só fiquei curioso o porquê da minha sessão R dar certo, mas a minha máquina no plumber não…
Ambas eu dou um library(workflows) no início, e ambas eu chamo a previsão com stats::predict(). Só que na sessão do R rodava tranquila.

q bom q deu!

E tbm achei curioso. concordo que era pra ter dado certo, msm!. Acho que na hora da prod, o melhor é tentar lotar de :: ou ::: pra evitar esses comportamentos esquisitos.