Torch Integration

Custom serialization functions may be registered to handle reference objects of external pointer type.

This allows tensors from the {torch} package to be used seamlessly in ‘mirai’ computations.

Setup Steps

  1. Register the serialization and unserialization functions as a list supplied to serialization().

  2. Set up daemons - this may be done before or after calling serialization().

  3. Use everywhere() to make the {torch} package available on all daemons (optional, for convenience).

library(mirai)
library(torch)

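# Register the serialization and unserialization functions for "torch_tensor" objects
serialization(refhook = list(torch:::torch_serialize, torch::torch_load),
              class = "torch_tensor",
              list = TRUE)
# Set up a single daemon
daemons(1)
#> [1] 1
# Load {torch} on all daemons
everywhere(library(torch))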

Example Usage

The example below creates a convolutional neural network using torch::nn_module().

A set of model parameters is also specified.

The model specification and parameters are then passed to and initialized within a ‘mirai’.

model <- nn_module(
  initialize = function(in_size, out_size) {
    self$conv1 <- nn_conv2d(in_size, out_size, 5)
    self$conv2 <- nn_conv2d(in_size, out_size, 5)
  },
  forward = function(x) {
    x <- self$conv1(x)
    x <- nnf_relu(x)
    x <- self$conv2(x)
    x <- nnf_relu(x)
    x
  }
)

params <- list(in_size = 1, out_size = 20)

m <- mirai(do.call(model, params), .args = list(model, params))

call_mirai(m)$data
#> An `nn_module` containing 1,040 parameters.
#> 
#> ── Modules ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
#> • conv1: <nn_conv2d> #520 parameters
#> • conv2: <nn_conv2d> #520 parameters
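
The parameter counts printed above can be checked by hand. As a rough sketch, assuming each convolution layer carries a bias term (the nn_conv2d() default), a layer with a square kernel of size k holds in_size x out_size x k x k weights plus out_size biases. The helper conv2d_param_count() below is hypothetical, is not part of {torch} or {mirai}, and uses base R only.

```r
# Hypothetical helper: number of parameters in a 2-d convolution layer
# with a square kernel and a bias term (assumed present, as by default)
conv2d_param_count <- function(in_size, out_size, kernel) {
  weights <- in_size * out_size * kernel * kernel
  biases <- out_size
  weights + biases
}

per_layer <- conv2d_param_count(in_size = 1, out_size = 20, kernel = 5)
per_layer
#> [1] 520

# Both layers in the example are constructed with the same arguments
2 * per_layer
#> [1] 1040
```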

The returned model is an object containing many tensor elements.

m$data$parameters$conv1.weight
#> torch_tensor
#> (1,1,.,.) = 
#>  -0.1377 -0.1946 -0.0898 -0.0106  0.1031
#>   0.0351 -0.1507  0.0192 -0.0240  0.1700
#>   0.1107 -0.1232 -0.0690 -0.0868  0.0884
#>  -0.0267 -0.0022  0.1997 -0.1383  0.0435
#>  -0.1809 -0.1085 -0.0297 -0.0644 -0.0401
#> 
#> (2,1,.,.) = 
#>  -0.1002  0.1828  0.0886 -0.1793  0.1938
#>  -0.0046  0.1006 -0.0480 -0.0389  0.0083
#>   0.0761  0.0242  0.0283 -0.1859  0.1711
#>   0.0402 -0.1847  0.1351  0.1842  0.0094
#>   0.1114 -0.1828 -0.1846 -0.0650  0.1380
#> 
#> (3,1,.,.) = 
#>   0.1659 -0.0655  0.0936  0.1089  0.0514
#>  -0.0058  0.1683 -0.0303 -0.0817 -0.1813
#>  -0.0236 -0.1817 -0.1238  0.1651  0.1937
#>  -0.1627 -0.0650  0.1760  0.0215 -0.0887
#>  -0.0851  0.1430  0.1322 -0.1617 -0.1646
#> 
#> (4,1,.,.) = 
#>  -0.0520 -0.0200  0.0354 -0.1181 -0.1892
#>   0.1204  0.0877  0.1280 -0.0056  0.0088
#>  -0.1765 -0.0036 -0.0664  0.0417  0.0189
#>   0.1327 -0.1689  0.0132  0.0849 -0.0634
#>   0.0072 -0.1516  0.0633  0.0185  0.0992
#> 
#> (5,1,.,.) = 
#>   0.0992  0.0912 -0.1982  0.1880 -0.0986
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{20,1,5,5} ][ requires_grad = TRUE ]

It is usual for model parameters to then be passed to an optimiser.

This can also be initialized within a ‘mirai’ process.

optim <- mirai(optim_rmsprop(params = params), params = m$data$parameters)

call_mirai(optim)$data
#> <optim_rmsprop>
#>   Inherits from: <torch_optimizer>
#>   Public:
#>     add_param_group: function (param_group) 
#>     clone: function (deep = FALSE) 
#>     defaults: list
#>     initialize: function (params, lr = 0.01, alpha = 0.99, eps = 1e-08, weight_decay = 0, 
#>     load_state_dict: function (state_dict, ..., .refer_to_state_dict = FALSE) 
#>     param_groups: list
#>     state: State, R6
#>     state_dict: function () 
#>     step: function (closure = NULL) 
#>     zero_grad: function () 
#>   Private:
#>     step_helper: function (closure, loop_fun)

daemons(0)
#> [1] 0

Above, tensors and complex objects containing tensors were passed seamlessly between host and daemon processes, in the same way as any other R object.

The custom serialization in {mirai} leverages R’s own native ‘refhook’ mechanism to allow such completely transparent usage.

It is designed to be fast and efficient, minimising data copies and using the serialization methods from the ‘torch’ package directly.
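
The ‘refhook’ mechanism referred to above is exposed by base R through the refhook argument of serialize() and unserialize(). The following is a minimal sketch of that mechanism alone and not of the {mirai} implementation. It assumes an environment with a hypothetical class "demo_handle" as a stand-in for an external pointer (both are reference objects passed to the hook), and the names new_handle, out_hook and in_hook are illustrative only.

```r
# Constructor for a toy reference object (hypothetical class "demo_handle")
new_handle <- function(value) {
  h <- new.env()
  h$value <- value
  class(h) <- "demo_handle"
  h
}

# Called by serialize() for each reference object it encounters.
# Returning a character vector stores that in place of the object;
# returning NULL falls back to default serialization.
out_hook <- function(x) {
  if (inherits(x, "demo_handle")) as.character(x$value) else NULL
}

# Called by unserialize() with the stored character vector;
# it rebuilds the reference object.
in_hook <- function(chr) new_handle(as.numeric(chr))

# A regular R object containing a reference object
obj <- list(label = "example", handle = new_handle(42))

raw_bytes <- serialize(obj, NULL, refhook = out_hook)
restored <- unserialize(raw_bytes, refhook = in_hook)

restored$handle$value
#> [1] 42
```

The hooks are only consulted for the reference objects nested inside obj, so the rest of the object is serialized as usual. This is what allows a container holding such objects to be handled in the same way as any other R object.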