K Means models

Edgar Ruiz

2022-08-16

Intro

The simple_kmeans_db() function enables running the KMeans model inside the database. It uses dplyr programming to abstract the steps needed to produce a model, so that they can then be translated into SQL statements in the background.

Example setup

In this example, a simple RSQLite database will be used to load the flights data from the nycflights13 library.

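library(dplyr)

# Open an in-memory SQLite database
con <- DBI::dbConnect(RSQLite::SQLite(), dbname = ":memory:")
# Load SQLite extension functions, such as the math functions used by the model
RSQLite::initExtension(con)

# Copy the flights data into the database; db_flights points to the remote table
db_flights <- copy_to(con, nycflights13::flights, "flights")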

Running Kmeans clustering

The simple_kmeans_db() function can be used with local data or with a remote table, such as the db_flights variable, which is a pointer to the “flights” table inside the SQLite database. When piping to the function, the only other required arguments are two or more fields, separated by commas. Because it uses ‘tidyeval’, variable name auto-completion will work.

library(modeldb)

km <- db_flights %>%
  simple_kmeans_db(dep_time, distance)

The simple_kmeans_db() function uses a progress bar to show the current cycle, the maximum number of cycles it is expected to run, the difference between the previous cycle and the current cycle, and the running time. The loop stops once it either has two matching consecutive cycles or reaches the maximum number of cycles, as determined by the max_repeats argument.

The final centers are stored in the centers variable of the returned object:

km$centers
#> NULL

The latest results are stored in the tbl variable of the returned object. The type of the returned table will match the type of the original source, so if the source is remote, such as a database table, then tbl will be of class tbl_sql. Among other things, this makes it possible to view the SQL statement that produces the results:

dbplyr::remote_query(km)
#> <SQL> SELECT `k_center`, `k_dep_time`, `k_distance`, `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, `LHS`.`k_center` AS `k_center`, `k_dep_time`, `k_distance`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, `center` AS `k_center`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, `center_1`, `center_2`, `center_3`, CASE
#> WHEN (`center_1` >= `center_1` AND `center_1` < `center_2` AND `center_1` < `center_3`) THEN ('center_1')
#> WHEN (`center_2` < `center_1` AND `center_2` >= `center_2` AND `center_2` < `center_3`) THEN ('center_2')
#> WHEN (`center_3` < `center_1` AND `center_3` < `center_2` AND `center_3` >= `center_3`) THEN ('center_3')
#> END AS `center`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, SQRT(((889.757881651311 - `dep_time`) * (889.757881651311 - `dep_time`)) + ((791.286862996562 - `distance`) * (791.286862996562 - `distance`))) AS `center_1`, SQRT(((1391.08534916316 - `dep_time`) * (1391.08534916316 - `dep_time`)) + ((2355.04462033144 - `distance`) * (2355.04462033144 - `distance`))) AS `center_2`, SQRT(((1745.74853136521 - `dep_time`) * (1745.74853136521 - `dep_time`)) + ((718.043515631104 - `distance`) * (718.043515631104 - `distance`))) AS `center_3`
#> FROM `flights`))
#> WHERE (NOT(((`center`) IS NULL)))) AS `LHS`
#> LEFT JOIN (SELECT `center` AS `k_center`, `dep_time` AS `k_dep_time`, `distance` AS `k_distance`
#> FROM (SELECT `center`, AVG(`dep_time`) AS `dep_time`, AVG(`distance`) AS `distance`
#> FROM (SELECT `dep_time`, `distance`, `center`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, `center_1`, `center_2`, `center_3`, CASE
#> WHEN (`center_1` >= `center_1` AND `center_1` < `center_2` AND `center_1` < `center_3`) THEN ('center_1')
#> WHEN (`center_2` < `center_1` AND `center_2` >= `center_2` AND `center_2` < `center_3`) THEN ('center_2')
#> WHEN (`center_3` < `center_1` AND `center_3` < `center_2` AND `center_3` >= `center_3`) THEN ('center_3')
#> END AS `center`
#> FROM (SELECT `year`, `month`, `day`, `dep_time`, `sched_dep_time`, `dep_delay`, `arr_time`, `sched_arr_time`, `arr_delay`, `carrier`, `flight`, `tailnum`, `origin`, `dest`, `air_time`, `distance`, `hour`, `minute`, `time_hour`, SQRT(((889.757881651311 - `dep_time`) * (889.757881651311 - `dep_time`)) + ((791.286862996562 - `distance`) * (791.286862996562 - `distance`))) AS `center_1`, SQRT(((1391.08534916316 - `dep_time`) * (1391.08534916316 - `dep_time`)) + ((2355.04462033144 - `distance`) * (2355.04462033144 - `distance`))) AS `center_2`, SQRT(((1745.74853136521 - `dep_time`) * (1745.74853136521 - `dep_time`)) + ((718.043515631104 - `distance`) * (718.043515631104 - `distance`))) AS `center_3`
#> FROM `flights`))
#> WHERE (NOT(((`center`) IS NULL))))
#> GROUP BY `center`)) AS `RHS`
#> ON (`LHS`.`k_center` = `RHS`.`k_center`)
#> )

Under the hood

The simple_kmeans_db() function uses dplyr and ‘tidyeval’ to run the KMeans algorithm. This means that when combined with dbplyr, the routines can be run inside a database.

Unlike other packages that use this same methodology, such as dbplot and tidypredict, simple_kmeans_db() does not create a single piece of dplyr code that can be extracted as SQL. The function produces multiple, serial and dependent SQL statements that run individually inside the database. Each statement uses the current centroids, or centers, to estimate new centroids, and then it uses those centroids in a subsequent SQL statement to see if there was any variance. Effectively, this approach uses R not only as a translation layer, but also as an orchestration layer.
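To make the cycle described above concrete, here is a minimal local sketch in base R. It is not the modeldb implementation: kmeans_sketch() and its arguments are hypothetical names, and the data is a small made-up matrix. It only mirrors the steps visible in the generated SQL (Euclidean distance to each center, assignment to the nearest center, averaging per center) and the stopping rule described earlier.

```r
# Illustrative only: a local, base-R analogue of the cycle described above.
# kmeans_sketch() is a hypothetical helper, not part of modeldb.
kmeans_sketch <- function(x, centers, max_repeats = 100) {
  x <- as.matrix(x)
  centers <- as.matrix(centers)
  for (cycle in seq_len(max_repeats)) {
    # Euclidean distance from every row to every center (rows x centers)
    dists <- sapply(seq_len(nrow(centers)), function(k) {
      sqrt(rowSums(sweep(x, 2, centers[k, ])^2))
    })
    # Assign each row to its nearest center
    assigned <- max.col(-dists, ties.method = "first")
    # New centers are the column means of the rows assigned to each center
    new_centers <- centers
    for (k in unique(assigned)) {
      new_centers[k, ] <- colMeans(x[assigned == k, , drop = FALSE])
    }
    # Stop once two consecutive cycles produce matching centers
    if (isTRUE(all.equal(new_centers, centers))) break
    centers <- new_centers
  }
  list(centers = centers, cluster = assigned, cycles = cycle)
}

# Made-up data: two well-separated groups of four points each
pts <- rbind(
  cbind(c(1, 2, 1, 2), c(1, 1, 2, 2)),
  cbind(c(11, 12, 11, 12), c(11, 11, 12, 12))
)

# Deliberately poor starting centers: both taken from the first group
fit <- kmeans_sketch(pts, centers = pts[c(1, 2), ])
fit$centers
fit$cycles
```

In the database version, the distance, assignment and averaging steps of each cycle are expressed as SQL, while a loop like the one above stays in R and decides whether another statement needs to be sent.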

Safeguards for long running jobs

The simple_kmeans_db() approach uses multiple, consecutive SQL queries to find the optimal centers. Additionally, in KMeans clustering, the order in which each set of centers is passed matters. This creates an imperative to cache the current centers used in a long-running job, in case the job is canceled or fails. Restarting from the centers that were calculated last means that the job will not begin from “0”, but from a more advanced (read: closer) set of centers.

The safeguard implemented in this function is through a file, called kmeans.csv. Each cycle will update the file. The file name can be changed by modifying the safeguard_file argument. Setting the argument to NULL will turn off the safeguard. The file will be saved to the temporary directory of the R session.
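The caching idea can be sketched in a few lines of base R. This is an assumption about the general pattern, not the package's internal code: save_centers() is a hypothetical helper, and the centers are made-up values. It writes the current centers to a CSV file in the session's temporary directory and reads them back, as one would before restarting a job.

```r
# Hypothetical helper illustrating the safeguard idea; not modeldb's code.
save_centers <- function(centers, file = file.path(tempdir(), "kmeans.csv")) {
  # Overwrite the file so it always holds the most recent centers
  write.csv(centers, file, row.names = FALSE)
  invisible(file)
}

# Made-up centers standing in for the result of one cycle
centers <- data.frame(x = c(1.5, 11.5), y = c(1.5, 11.5))

# Called once per cycle in a real loop
path <- save_centers(centers)

# After a canceled or failed job, the last saved centers can be recovered
restored <- read.csv(path)
restored
```

Because the file is rewritten on every cycle, a job that stops partway leaves behind the centers from its last completed cycle.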

In this example, we will set max_repeats to 10, so as to artificially avoid finding the optimal means:

km <- db_flights %>%
  simple_kmeans_db(dep_time, distance, max_repeats = 10)

In the next run, the “kmeans.csv” file is passed as the initial_kmeans argument. This will make simple_kmeans_db() use those centers as the starting point:

km <- db_flights %>%
  simple_kmeans_db(dep_time, distance, initial_kmeans = read.csv(file.path(tempdir(), "kmeans.csv")))

The second run took 7 cycles to complete. Added to the 10 cycles of the capped run, that makes 17 cycles, the same number that the first example at the top of this article initially took.