Title: | 'ggplot' Visualizations for the 'partykit' Package |
---|---|
Description: | Extends 'ggplot2' functionality to the 'partykit' package. 'ggparty' provides the necessary tools to create clearly structured and highly customizable visualizations for tree-objects of the class 'party'. |
Authors: | Martin Borkovec [aut, cre], Niyaz Madin [aut], Hadley Wickham [ctb], Winston Chang [ctb], Lionel Henry [ctb], Thomas Lin Pedersen [ctb], Kohske Takahashi [ctb], Claus Wilke [ctb], Kara Woo [ctb], Hiroaki Yutani [ctb] |
Maintainer: | Martin Borkovec <[email protected]> |
License: | GPL-2 | GPL-3 |
Version: | 1.0.0 |
Built: | 2025-02-28 05:33:57 UTC |
Source: | https://github.com/martin-borkovec/ggparty |
autoplot methods for party objects
## S3 method for class 'party' autoplot(object, ...) ## S3 method for class 'constparty' autoplot(object, ...) ## S3 method for class 'modelparty' autoplot(object, plot_var = NULL, ...) ## S3 method for class 'lmtree' autoplot(object, plot_var = NULL, show_fit = TRUE, ...)
## S3 method for class 'party' autoplot(object, ...) ## S3 method for class 'constparty' autoplot(object, ...) ## S3 method for class 'modelparty' autoplot(object, plot_var = NULL, ...) ## S3 method for class 'lmtree' autoplot(object, plot_var = NULL, show_fit = TRUE, ...)
object |
object of class party. |
... |
additional parameters |
plot_var |
Which covariate to plot against response. Defaults to second
column in |
show_fit |
If TRUE |
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) autoplot(py)
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) autoplot(py)
Draws edges between children and parent nodes. Wrapper for ggplot2::geom_segment()
geom_edge(mapping = NULL, nudge_x = 0, nudge_y = 0, ids = NULL, show.legend = NA, ...)
geom_edge(mapping = NULL, nudge_x = 0, nudge_y = 0, ids = NULL, show.legend = NA, ...)
mapping |
Mapping of |
nudge_x , nudge_y
|
Nudge edges. |
ids |
Choose which edges to draw by their children's ids. |
show.legend |
|
... |
Additional arguments for |
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) ggparty(py) + geom_edge() + geom_edge_label() + geom_node_label(aes(label = splitvar), ids = "inner") + geom_node_label(aes(label = info), ids = "terminal")
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) ggparty(py) + geom_edge() + geom_edge_label() + geom_node_label(aes(label = splitvar), ids = "inner") + geom_node_label(aes(label = info), ids = "terminal")
Label edges with corresponding split breaks
geom_edge_label(mapping = NULL, nudge_x = 0, nudge_y = 0, ids = NULL, shift = 0.5, label.size = 0, splitlevels = seq_len(100), max_length = NULL, parse_all = FALSE, parse = TRUE, ...)
geom_edge_label(mapping = NULL, nudge_x = 0, nudge_y = 0, ids = NULL, shift = 0.5, label.size = 0, splitlevels = seq_len(100), max_length = NULL, parse_all = FALSE, parse = TRUE, ...)
mapping |
Mapping of |
nudge_x , nudge_y
|
Nudge label. |
ids |
Choose which splitbreaks to label by their children's ids. |
shift |
Value in (0,1). Moves label along corresponding edge. |
label.size |
See |
splitlevels |
Which levels of split to plot. This may be useful in the presence of many factor levels for one split break. |
max_length |
If provided breaks_label levels will be truncated to the specified length. |
parse_all |
Defaults to |
parse |
Needs to be true in order to parse inequality signs of breaks_label. |
... |
Additional arguments for |
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) ggparty(py) + geom_edge() + geom_edge_label() + geom_node_label(aes(label = splitvar), ids = "inner") + geom_node_label(aes(label = info), ids = "terminal")
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) ggparty(py) + geom_edge() + geom_edge_label() + geom_node_label(aes(label = splitvar), ids = "inner") + geom_node_label(aes(label = info), ids = "terminal")
geom_node_splitvar()
and geom_node_info()
are simplified versions of
geom_node_label()
with the respective defaults to either label the split variables
for all inner nodes or the info for all terminal nodes.
geom_node_label(mapping = NULL, data = NULL, line_list = NULL, line_gpar = NULL, ids = NULL, position = "identity", ..., parse = FALSE, nudge_x = 0, nudge_y = 0, label.padding = unit(0.25, "lines"), label.r = unit(0.15, "lines"), label.size = 0.25, label.col = NULL, label.fill = NULL, na.rm = FALSE, show.legend = NA, inherit.aes = TRUE) geom_node_info(mapping = NULL, nudge_x = 0, nudge_y = 0, ids = NULL, label.padding = unit(0.5, "lines"), ...) geom_node_splitvar(mapping = NULL, nudge_x = 0, nudge_y = 0, label.padding = unit(0.5, "lines"), ids = NULL, ...)
geom_node_label(mapping = NULL, data = NULL, line_list = NULL, line_gpar = NULL, ids = NULL, position = "identity", ..., parse = FALSE, nudge_x = 0, nudge_y = 0, label.padding = unit(0.25, "lines"), label.r = unit(0.15, "lines"), label.size = 0.25, label.col = NULL, label.fill = NULL, na.rm = FALSE, show.legend = NA, inherit.aes = TRUE) geom_node_info(mapping = NULL, nudge_x = 0, nudge_y = 0, ids = NULL, label.padding = unit(0.5, "lines"), ...) geom_node_splitvar(mapping = NULL, nudge_x = 0, nudge_y = 0, label.padding = unit(0.5, "lines"), ids = NULL, ...)
mapping |
|
data |
The data to be displayed in this layer. There are three options: If A A |
line_list |
Use this only if you want a multi-line label with the
possibility to override the aesthetics mapping for each line specifically
with fixed graphical parameters. In this case, don't map anything to
|
line_gpar |
List of lists containing line-specific graphical parameters.
Only use in
conjunction with |
ids |
Select for which nodes to draw a label. Can be |
position |
Position adjustment, either as a string, or the result of a call to a position adjustment function. |
... |
Additional arguments to layer. |
parse |
If |
nudge_x , nudge_y
|
Adjust position of label. |
label.padding |
Amount of padding around label. Defaults to 0.25 lines. |
label.r |
Radius of rounded corners. Defaults to 0.15 lines. |
label.size |
Size of label border, in mm. |
label.col |
Border colour. |
label.fill |
Background colour. |
na.rm |
If |
show.legend |
logical. Should this layer be included in the legends?
|
inherit.aes |
If |
geom_node_label()
is a modified version of ggplot2::geom_label()
. This
modification allows for labels with multiple lines and line specific graphical
parameters.
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) ggparty(py) + geom_edge() + geom_edge_label() + geom_node_label(aes(label = splitvar), ids = "inner") + geom_node_label(aes(label = info), ids = "terminal") ###################################### data("TeachingRatings", package = "AER") tr <- subset(TeachingRatings, credits == "more") tr_tree <- lmtree(eval ~ beauty | minority + age + gender + division + native + tenure, data = tr, weights = students, caseweights = FALSE) data("TeachingRatings", package = "AER") tr <- subset(TeachingRatings, credits == "more") tr_tree <- lmtree(eval ~ beauty | minority + age + gender + division + native + tenure, data = tr, weights = students, caseweights = FALSE) ggparty(tr_tree, terminal_space = 0.5, add_vars = list(p.value = "$node$info$p.value")) + geom_edge(size = 1.5) + geom_edge_label(colour = "grey", size = 6) + geom_node_plot(gglist = list(geom_point(aes(x = beauty, y = eval, col = tenure, shape = minority), alpha = 0.8), theme_bw(base_size = 15)), scales = "fixed", id = "terminal", shared_axis_labels = TRUE, shared_legend = TRUE, legend_separator = TRUE, predict = "beauty", predict_gpar = list(col = "blue", size = 1.2) ) + geom_node_label(aes(col = splitvar), line_list = list(aes(label = paste("Node", id)), aes(label = splitvar), aes(label = paste("p =", formatC(p.value, format = "e", digits = 2)))), line_gpar = list(list(size = 12, col = "black", fontface = "bold"), list(size = 20), list(size = 12)), ids = "inner") + geom_node_label(aes(label = paste0("Node ", id, ", N = ", nodesize)), fontface = "bold", 
ids = "terminal", size = 5, nudge_y = 0.01) + theme(legend.position = "none")
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) ggparty(py) + geom_edge() + geom_edge_label() + geom_node_label(aes(label = splitvar), ids = "inner") + geom_node_label(aes(label = info), ids = "terminal") ###################################### data("TeachingRatings", package = "AER") tr <- subset(TeachingRatings, credits == "more") tr_tree <- lmtree(eval ~ beauty | minority + age + gender + division + native + tenure, data = tr, weights = students, caseweights = FALSE) data("TeachingRatings", package = "AER") tr <- subset(TeachingRatings, credits == "more") tr_tree <- lmtree(eval ~ beauty | minority + age + gender + division + native + tenure, data = tr, weights = students, caseweights = FALSE) ggparty(tr_tree, terminal_space = 0.5, add_vars = list(p.value = "$node$info$p.value")) + geom_edge(size = 1.5) + geom_edge_label(colour = "grey", size = 6) + geom_node_plot(gglist = list(geom_point(aes(x = beauty, y = eval, col = tenure, shape = minority), alpha = 0.8), theme_bw(base_size = 15)), scales = "fixed", id = "terminal", shared_axis_labels = TRUE, shared_legend = TRUE, legend_separator = TRUE, predict = "beauty", predict_gpar = list(col = "blue", size = 1.2) ) + geom_node_label(aes(col = splitvar), line_list = list(aes(label = paste("Node", id)), aes(label = splitvar), aes(label = paste("p =", formatC(p.value, format = "e", digits = 2)))), line_gpar = list(list(size = 12, col = "black", fontface = "bold"), list(size = 20), list(size = 12)), ids = "inner") + geom_node_label(aes(label = paste0("Node ", id, ", N = ", nodesize)), fontface = "bold", 
ids = "terminal", size = 5, nudge_y = 0.01) + theme(legend.position = "none")
Additional component for a ggparty()
that allows creating in each node a
ggplot with its data.
geom_node_plot(plot_call = "ggplot", gglist = NULL, width = 1, height = 1, size = 1, ids = "terminal", scales = "fixed", nudge_x = 0, nudge_y = 0, shared_axis_labels = FALSE, shared_legend = TRUE, predict = NULL, predict_gpar = NULL, legend_separator = FALSE)
geom_node_plot(plot_call = "ggplot", gglist = NULL, width = 1, height = 1, size = 1, ids = "terminal", scales = "fixed", nudge_x = 0, nudge_y = 0, shared_axis_labels = FALSE, shared_legend = TRUE, predict = NULL, predict_gpar = NULL, legend_separator = FALSE)
plot_call |
Any function that generates a |
gglist |
List of additional |
width |
Expansion factor for viewport's width. |
height |
Expansion factor for viewport's height. |
size |
Expansion factor for viewport's size. |
ids |
Ids to plot. Numeric, "terminal", "inner" or "all". Defaults to "terminal". |
scales |
See |
nudge_x , nudge_y
|
Nudges node plot. |
shared_axis_labels |
If TRUE, only one pair of axis labels is plotted in
the terminal space. Only recommended if |
shared_legend |
If |
predict |
Character string specifying variable for which predictions should be plotted. |
predict_gpar |
Named list containing arguments to be passed to the
|
legend_separator |
If |
library(ggparty) airq <- subset(airquality, !is.na(Ozone)) airct <- ctree(Ozone ~ ., data = airq) ggparty(airct, horizontal = TRUE, terminal_space = 0.6) + geom_edge() + geom_edge_label() + geom_node_splitvar() + geom_node_plot(gglist = list( geom_density(aes(x = Ozone))), shared_axis_labels = TRUE) ############################################################# ## Plot with ggparty ## Demand for economics journals data data("Journals", package = "AER") Journals <- transform(Journals, age = 2000 - foundingyear, chars = charpp * pages) ## linear regression tree (OLS) j_tree <- lmtree(log(subs) ~ log(price/citations) | price + citations + age + chars + society, data = Journals, minsize = 10, verbose = TRUE) pred_df <- get_predictions(j_tree, ids = "terminal", newdata = function(x) { data.frame( citations = 1, price = exp(seq(from = min(x$`log(price/citations)`), to = max(x$`log(price/citations)`), length.out = 100))) }) ggparty(j_tree, terminal_space = 0.8) + geom_edge() + geom_edge_label() + geom_node_splitvar() + geom_node_plot(gglist = list(aes(x = `log(price/citations)`, y = `log(subs)`), geom_point(), geom_line(data = pred_df, aes(x = log(price/citations), y = prediction), col = "red")))
library(ggparty) airq <- subset(airquality, !is.na(Ozone)) airct <- ctree(Ozone ~ ., data = airq) ggparty(airct, horizontal = TRUE, terminal_space = 0.6) + geom_edge() + geom_edge_label() + geom_node_splitvar() + geom_node_plot(gglist = list( geom_density(aes(x = Ozone))), shared_axis_labels = TRUE) ############################################################# ## Plot with ggparty ## Demand for economics journals data data("Journals", package = "AER") Journals <- transform(Journals, age = 2000 - foundingyear, chars = charpp * pages) ## linear regression tree (OLS) j_tree <- lmtree(log(subs) ~ log(price/citations) | price + citations + age + chars + society, data = Journals, minsize = 10, verbose = TRUE) pred_df <- get_predictions(j_tree, ids = "terminal", newdata = function(x) { data.frame( citations = 1, price = exp(seq(from = min(x$`log(price/citations)`), to = max(x$`log(price/citations)`), length.out = 100))) }) ggparty(j_tree, terminal_space = 0.8) + geom_edge() + geom_edge_label() + geom_node_splitvar() + geom_node_plot(gglist = list(aes(x = `log(price/citations)`, y = `log(subs)`), geom_point(), geom_line(data = pred_df, aes(x = log(price/citations), y = prediction), col = "red")))
Create data.frame with predictions for each node
get_predictions(party_object, ids, newdata_fun, predict_arg = NULL)
get_predictions(party_object, ids, newdata_fun, predict_arg = NULL)
party_object |
object of class |
ids |
Ids to plot. Numeric, "terminal", "inner" or "all". MUST be identical
to |
newdata_fun |
function which takes |
predict_arg |
list of additional arguments passed to |
ggplot2
extension for objects of class party
. Creates a data.frame
from
an object of class party
and calls ggplot()
ggparty(party, horizontal = FALSE, terminal_space, layout = NULL, add_vars = NULL)
ggparty(party, horizontal = FALSE, terminal_space, layout = NULL, add_vars = NULL)
party |
Object of class |
horizontal |
If |
terminal_space |
Proportion of the plot that should be reserved for
the terminal nodeplots. Defaults to |
layout |
Optional layout adjustment. Overwrites the coordinates of the
specified nodes. Must be |
add_vars |
Named list containing either string(s) specifying the locations
of elements to be extracted from
each node of |
ggparty
can be called directly with an object of class party
, which will
convert it to a suitable data.frame
and pass it to a call to ggplot
as
the data
argument. As usual, additional components can then be added with
+
.
The nodes will be spaced equally in the unit square. Specifying
terminal_space
allows you to increase or decrease the area for plots of the
terminal nodes.
If one of the list entries supplied to add_vars
is a function, it has to take
exactly two arguments,
namely data
(the corresponding row of the plot_data data frame) and node
(the corresponding node, i.e. party_object[i]
)
geom_edge()
, geom_edge_label()
, geom_node_label()
,
autoplot.party()
, geom_node_plot()
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) ggparty(py) + geom_edge() + geom_edge_label() + geom_node_label(aes(label = splitvar), ids = "inner") + geom_node_label(aes(label = info), ids = "terminal")
library(ggparty) data("WeatherPlay", package = "partykit") sp_o <- partysplit(1L, index = 1:3) sp_h <- partysplit(3L, breaks = 75) sp_w <- partysplit(4L, index = 1:2) pn <- partynode(1L, split = sp_o, kids = list( partynode(2L, split = sp_h, kids = list( partynode(3L, info = "yes"), partynode(4L, info = "no"))), partynode(5L, info = "yes"), partynode(6L, split = sp_w, kids = list( partynode(7L, info = "yes"), partynode(8L, info = "no"))))) py <- party(pn, WeatherPlay) ggparty(py) + geom_edge() + geom_edge_label() + geom_node_label(aes(label = splitvar), ids = "inner") + geom_node_label(aes(label = info), ids = "terminal")
apparently needs to be exported
## S3 method for class 'nodeplotgrob' makeContent(x)
## S3 method for class 'nodeplotgrob' makeContent(x)
x |
nodeplotgrob |