The simple_kmeans_db() function runs the KMeans model inside the database. It uses dplyr programming to abstract the steps needed to produce a model, so that they can be translated into SQL statements in the background.
In this example, a simple RSQLite database is used to load the flights data from the nycflights13 package.
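library(dplyr)
# In-memory SQLite database (RSQLite takes the database path via `dbname`)
con <- DBI::dbConnect(RSQLite::SQLite(), dbname = ":memory:")
# Load SQLite extension functions, such as SQRT(), used by the distance calculations
RSQLite::initExtension(con)
db_flights <- copy_to(con, nycflights13::flights, "flights")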
The simple_kmeans_db() function can be used with local data or with a remote table, such as db_flights, which is a pointer to the “flights” table inside the SQLite database. When piping to the function, the only other required arguments are two or more fields, separated by commas. Because the function uses ‘tidyeval’, variable-name auto-completion works.
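The fields are passed as bare column names rather than strings. As a rough illustration of that idea only, the sketch below captures bare names with base R's substitute() instead of the rlang/tidyeval machinery the function actually uses. The helper field_names() is hypothetical and not part of modeldb; it also enforces the "two or more fields" requirement described above.

```r
# Hypothetical helper (not part of modeldb): capture bare column names,
# using base R's substitute() in place of tidyeval quosures.
field_names <- function(...) {
  # substitute() returns the unevaluated call list(dep_time, distance, ...)
  exprs <- as.list(substitute(list(...)))[-1]
  fields <- vapply(exprs, deparse, character(1))
  # At least two fields are required
  if (length(fields) < 2) stop("at least two fields are required")
  fields
}

field_names(dep_time, distance)
#> [1] "dep_time" "distance"
```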
library(modeldb)
km <- db_flights %>%
  simple_kmeans_db(dep_time, distance)
The simple_kmeans_db() function uses a progress bar to show the current cycle, the maximum number of cycles it is expected to run, the difference between the previous cycle and the current cycle, and the running time. The loop stops once it either has two matching consecutive cycles or reaches the maximum number of cycles, as determined by the max_repeats argument.
The final centers are stored in the centers variable of the returned object:
km$centers
#> NULL
The latest results are stored in the tbl variable of the returned object. The type of the returned table matches the type of the original source, so if the source is remote, such as a database table, tbl will be of class tbl_sql. This allows us, for example, to view the SQL query behind the results:
dbplyr::remote_query(km)
#> <SQL> SELECT `k_center`, `k_dep_time`, `k_distance`, `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, `LHS`.`k_center` AS `k_center`, `k_dep_time`, `k_distance`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, `center` AS `k_center`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, `center_1`, `center_2`, `center_3`, CASE
#> WHEN (`center_1` >= `center_1` AND `center_1` < `center_2` AND `center_1` < `center_3`) THEN ('center_1')
#> WHEN (`center_2` < `center_1` AND `center_2` >= `center_2` AND `center_2` < `center_3`) THEN ('center_2')
#> WHEN (`center_3` < `center_1` AND `center_3` < `center_2` AND `center_3` >= `center_3`) THEN ('center_3')
#> END AS `center`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, SQRT(((889.757881651311 - `dep_time`) * (889.757881651311 - `dep_time`)) + ((791.286862996562 - `distance`) * (791.286862996562 - `distance`))) AS `center_1`, SQRT(((1391.08534916316 - `dep_time`) * (1391.08534916316 - `dep_time`)) + ((2355.04462033144 - `distance`) * (2355.04462033144 - `distance`))) AS `center_2`, SQRT(((1745.74853136521 - `dep_time`) * (1745.74853136521 - `dep_time`)) + ((718.043515631104 - `distance`) * (718.043515631104 - `distance`))) AS `center_3`
#> FROM `flights`))
#> WHERE (NOT(((`center`) IS NULL)))) AS `LHS`
#> LEFT JOIN (SELECT `center` AS `k_center`, `dep_time` AS `k_dep_time`, `distance` AS `k_distance`
#> FROM (SELECT `center`, AVG(`dep_time`) AS `dep_time`, AVG(`distance`) AS `distance`
#> FROM (SELECT `dep_time`, `distance`, `center`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, `center_1`, `center_2`, `center_3`, CASE
#> WHEN (`center_1` >= `center_1` AND `center_1` < `center_2` AND `center_1` < `center_3`) THEN ('center_1')
#> WHEN (`center_2` < `center_1` AND `center_2` >= `center_2` AND `center_2` < `center_3`) THEN ('center_2')
#> WHEN (`center_3` < `center_1` AND `center_3` < `center_2` AND `center_3` >= `center_3`) THEN ('center_3')
#> END AS `center`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, SQRT(((889.757881651311 - `dep_time`) * (889.757881651311 - `dep_time`)) + ((791.286862996562 - `distance`) * (791.286862996562 - `distance`))) AS `center_1`, SQRT(((1391.08534916316 - `dep_time`) * (1391.08534916316 - `dep_time`)) + ((2355.04462033144 - `distance`) * (2355.04462033144 - `distance`))) AS `center_2`, SQRT(((1745.74853136521 - `dep_time`) * (1745.74853136521 - `dep_time`)) + ((718.043515631104 - `distance`) * (718.043515631104 - `distance`))) AS `center_3`
#> FROM `flights`))
#> WHERE (NOT(((`center`) IS NULL))))
#> GROUP BY `center`)) AS `RHS`
#> ON (`LHS`.`k_center` = `RHS`.`k_center`)
#> )
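In the query above, the current centers appear as numeric literals inside the SQRT() expressions that compute each row's Euclidean distance to every center. The sketch below shows how such expressions could be generated in R from a data frame of centers (one row per center, one column per field). The helper distance_sql() is hypothetical and is not how modeldb builds its SQL; modeldb relies on dplyr and dbplyr for the translation.

```r
# Hypothetical helper (not modeldb's implementation): build one
# SQRT(...) AS `center_k` expression per center, embedding the current
# center values as literals, as in the generated query above.
distance_sql <- function(centers) {
  vars <- colnames(centers)
  vapply(seq_len(nrow(centers)), function(k) {
    # One squared-difference term per field, e.g. ((900 - `dep_time`) * ...)
    terms <- vapply(vars, function(v) {
      val <- format(centers[[v]][k], digits = 15)
      sprintf("((%s - `%s`) * (%s - `%s`))", val, v, val, v)
    }, character(1))
    sprintf("SQRT(%s) AS `center_%d`", paste(terms, collapse = " + "), k)
  }, character(1))
}

# Toy centers, for illustration only
centers <- data.frame(dep_time = c(900, 1400), distance = c(800, 2300))
distance_sql(centers)[1]
#> [1] "SQRT(((900 - `dep_time`) * (900 - `dep_time`)) + ((800 - `distance`) * (800 - `distance`))) AS `center_1`"
```

Because the literals change whenever the centers change, each cycle needs a freshly generated SQL statement.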
The simple_kmeans_db() function uses dplyr and ‘tidyeval’ to run the KMeans algorithm. This means that, when combined with dbplyr, the routines can be run inside a database.
Unlike other packages that use this same methodology, such as dbplot and tidypredict, simple_kmeans_db() does not create a single dplyr pipeline that can be extracted as SQL. Instead, the function produces multiple serial, dependent SQL statements that run individually inside the database. Each statement uses the current centroids, or centers, to estimate new centroids, and the next SQL statement then uses those new centroids to check whether the centers changed. Effectively, this approach uses R not only as a translation layer but also as an orchestration layer.
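The orchestration loop can be sketched in plain R on a local data frame. This is a minimal sketch, not modeldb's code: the helper run_kmeans() is hypothetical, each loop iteration stands in for one SQL round trip, and it assumes Euclidean distance, stopping when two consecutive cycles produce exactly the same centers or when max_repeats is reached.

```r
# Hypothetical sketch (not modeldb's code): R drives the loop, and each
# iteration stands in for one SQL statement run inside the database.
run_kmeans <- function(df, centers, max_repeats) {
  x <- as.matrix(df[, colnames(centers), drop = FALSE])
  for (cycle in seq_len(max_repeats)) {
    cm <- as.matrix(centers)
    # Euclidean distance from every row (rows) to every center (columns)
    dist <- matrix(sapply(seq_len(nrow(cm)), function(k) {
      sqrt(rowSums(sweep(x, 2, cm[k, ])^2))
    }), nrow = nrow(x))
    # Nearest center per row (ties go to the first center in this sketch)
    assigned <- apply(dist, 1, which.min)
    # New centers: per-cluster averages (empty clusters are dropped)
    new_centers <- aggregate(df[, colnames(centers), drop = FALSE],
                             by = list(center = assigned), FUN = mean)[, -1, drop = FALSE]
    # Stop once two consecutive cycles produce the same centers
    same <- nrow(new_centers) == nrow(centers) &&
      all(as.matrix(new_centers) == as.matrix(centers))
    if (same) return(list(centers = new_centers, cycles = cycle))
    centers <- new_centers
  }
  list(centers = centers, cycles = max_repeats)
}

# Toy data, for illustration only
df <- data.frame(dep_time = c(600, 700, 1800, 1900),
                 distance = c(200, 300, 2400, 2500))
start <- data.frame(dep_time = c(600, 1800), distance = c(200, 2400))
fit <- run_kmeans(df, start, max_repeats = 10)
fit$cycles
#> [1] 2
```

Keeping the loop in R and only the per-cycle aggregation in the database is what lets each cycle depend on the result of the previous one.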
The simple_kmeans_db() approach uses multiple consecutive SQL queries to find the optimal centers. In addition, in KMeans clustering the order in which each set of centers is passed matters, because each cycle starts from the centers produced by the previous one. This makes it important to cache the current centers of a long-running job in case the job is canceled or fails. Restarting from the most recently calculated centers means that the job does not begin from “0”, but from a more advanced (read: closer) set of centers.
The safeguard implemented in this function is a file called kmeans.csv, which is updated on each cycle. The file is saved to the temporary directory of the R session. The file name can be changed with the safeguard_file argument, and setting that argument to NULL turns off the safeguard.
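A minimal sketch of the safeguard idea in base R: after each cycle the current centers overwrite a CSV file in the session's temporary directory, and a later run can read that file back as its starting centers. The helpers save_centers() and load_centers() are hypothetical; only the file name kmeans.csv, the temporary-directory location, and the NULL-disables behavior come from the text, and the file layout simple_kmeans_db() actually writes may differ.

```r
# Hypothetical helpers (not part of modeldb): checkpoint the current centers
# to a CSV file so an interrupted job can resume from them.
save_centers <- function(centers, safeguard_file = "kmeans.csv") {
  # NULL turns the safeguard off
  if (is.null(safeguard_file)) return(invisible(NULL))
  path <- file.path(tempdir(), safeguard_file)
  # Overwritten on every cycle, so it always holds the latest centers
  write.csv(centers, path, row.names = FALSE)
  invisible(path)
}

load_centers <- function(safeguard_file = "kmeans.csv") {
  read.csv(file.path(tempdir(), safeguard_file))
}

# Toy centers, for illustration only
centers <- data.frame(dep_time = c(900.5, 1400.25), distance = c(800.75, 2300.5))
save_centers(centers)
restored <- load_centers()
```

Reading the file back with read.csv() is the same step used when the saved centers are passed to initial_kmeans.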
In this example, max_repeats is set to 10 so that the run artificially stops before the optimal centers are found:
km <- db_flights %>%
  simple_kmeans_db(dep_time, distance, max_repeats = 10)
In the next run, the contents of the “kmeans.csv” file are passed as the initial_kmeans argument. This makes simple_kmeans_db() use those centers as the starting point:
km <- db_flights %>%
  simple_kmeans_db(
    dep_time, distance,
    initial_kmeans = read.csv(file.path(tempdir(), "kmeans.csv"))
  )
The second run took 7 cycles to complete; added to the 10 cycles of the previous run, that makes the 17 cycles that the first example at the top of this article took.