我想使用ggplot2and復制 plot.lda 列印方法tidymodels。有沒有一種優雅的方式來獲得情節?
我想我可以augment()通過使用predict()它并將其系結到原始數??據來偽造沒有 lda 方法的函式。
這是一個帶有基本 R 和tidymodels代碼的示例:
library(ISLR2)
library(MASS)
# First base R
train <- Smarket$Year < 2005
lda.fit <-
lda(
Direction ~ Lag1 Lag2,
data = Smarket,
subset = train
)
plot(lda.fit)
# Next tidymodels
library(tidyverse)
library(tidymodels)
library(discrim)
lda_spec <- discrim_linear() %>%
set_mode("classification") %>%
set_engine("MASS")
the_rec <- recipe(
Direction ~ Lag1 Lag2,
data = Smarket
)
the_workflow<- workflow() %>%
add_recipe(the_rec) %>%
add_model(lda_spec)
Smarket_train <- Smarket %>%
filter(Year != 2005)
the_workflow_fit_lda_fit <-
fit(the_workflow, data = Smarket_train) %>%
extract_fit_parsnip()
# now my attempt to do the plot
predictions <- predict(the_workflow_fit_lda_fit,
new_data = Smarket_train,
type = "raw"
)[[3]] %>%
as.vector()
bind_cols(Smarket_train, .fitted = predictions) %>%
ggplot(aes(x=.fitted))
geom_histogram(aes(y = stat(density)),binwidth = .5)
scale_x_continuous(breaks = seq(-4, 4, by = 2))
facet_grid(vars(Direction))
xlab("")
ylab("Density")
必須有更好的方法來做到這一點......想法?
uj5u.com熱心網友回復:
您可以通過使用組合做extract_fit_*()和parsnip:::repair_call()。該plot.lda()方法使用$callLDA fit 中的物件,我們需要對其進行調整,因為使用 tidymodels 的呼叫物件將與lda()直接使用不同。
library(ISLR2)
library(MASS)
# First base R
train <- Smarket$Year < 2005
lda.fit <-
lda(
Direction ~ Lag1 Lag2,
data = Smarket,
subset = train
)
# Next tidymodels
library(tidyverse)
library(tidymodels)
library(discrim)
lda_spec <- discrim_linear() %>%
set_mode("classification") %>%
set_engine("MASS")
the_rec <- recipe(
Direction ~ Lag1 Lag2,
data = Smarket
)
the_workflow <- workflow() %>%
add_recipe(the_rec) %>%
add_model(lda_spec)
Smarket_train <- Smarket %>%
filter(Year != 2005)
the_workflow_fit_lda_fit <-
fit(the_workflow, data = Smarket_train)
擬合兩個模型后,我們可以檢查$call物件,我們發現它們是不同的。
lda.fit$call
#> lda(formula = Direction ~ Lag1 Lag2, data = Smarket, subset = train)
extract_fit_engine(the_workflow_fit_lda_fit)$call
#> lda(formula = ..y ~ ., data = data)
該parsnip::repair_call()函式將替換data為我們傳入的資料。此外,我們將資料的回應重命名為..y以匹配呼叫。
the_workflow_fit_lda_fit %>%
extract_fit_parsnip() %>%
parsnip::repair_call(rename(Smarket_train, ..y = Direction)) %>%
extract_fit_engine() %>%
plot()

由reprex 包(v2.0.1)于 2021 年 11 月 12 日創建
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/358903.html
