Forecasting El Niño-Southern Oscillation (ENSO)

Today, we use the convLSTM introduced in a previous post to predict El Niño-Southern Oscillation (ENSO).

ENSO refers to a changing pattern of sea surface temperatures and sea-level pressures occurring in the equatorial Pacific. Of its three overall states, probably the best-known is El Niño. El Niño occurs when surface water temperatures in the eastern Pacific are higher than normal, and the strong winds that normally blow from east to west are unusually weak. The opposite conditions are termed La Niña. Everything in-between is classified as neutral.

ENSO has great impact on the weather worldwide, and routinely harms ecosystems and societies through storms, droughts and flooding, possibly resulting in famines and economic crises. The best societies can do is try to adapt and mitigate severe consequences. Such efforts are aided by accurate forecasts, the further ahead the better.

Here, deep learning (DL) can potentially help: Variables like sea surface temperatures and pressures are given on a spatial grid – that of the earth – and as we know, DL is good at extracting spatial (e.g., image) features. For ENSO prediction, architectures like convolutional neural networks (Ham, Kim, and Luo (2019a)) or convolutional-recurrent hybrids are habitually used. One such hybrid is just our convLSTM; it operates on sequences of features given on a spatial grid. Today, then, we'll be training a model for ENSO forecasting. This model will have a convLSTM for its central ingredient.

Before we start, a note. While our model fits in well with architectures described in the relevant papers, the same can't be said for the amount of training data used. For reasons of practicality, we use actual observations only; consequently, we end up with a small (relative to the task) dataset. In contrast, research papers tend to make use of climate simulations, resulting in significantly more data to work with.

From the outset, then, we don't expect stellar performance. Nonetheless, this should make for an interesting case study, and a useful code template for our readers to apply to their own data.

We will attempt to predict monthly average sea surface temperature in the Niño 3.4 region, as represented by the Niño 3.4 Index, plus categorization as one of El Niño, La Niña or neutral. Predictions will be based on prior monthly sea surface temperatures spanning a large portion of the globe.

On the input side, public and ready-to-use data may be downloaded from the Tokyo Climate Center; as to prediction targets, we obtain index and classification here.

Input and target data both are provided monthly. They intersect in the time period ranging from 1891-01-01 to 2020-08-01; so this is the range of dates we'll be zooming in on.

Input: Sea Surface Temperatures

Monthly sea surface temperatures are provided on a latitude-longitude grid of resolution 1°. Details of how the data were processed are available here.

Data files are available in GRIB format; each file contains averages computed for a single month. We can either download individual files or generate a text file of URLs for download. In case you'd like to follow along with the post, you'll find the contents of the text file I generated in the appendix. Once you've saved these URLs to a file, you can have R get the files for you like so:
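A sketch of that download step (the file name "files" and the directory grb_dir are the names reused later in this post; adjust to your setup):

```r
library(purrr)

grb_dir <- "grb"
if (!dir.exists(grb_dir)) dir.create(grb_dir)

# one URL per line in "files"; download each GRIB file under its original name
walk(
  readLines("files", warn = FALSE),
  ~ download.file(.x, destfile = file.path(grb_dir, basename(.x)), mode = "wb")
)
```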

From R, we can read GRIB files using stars. For example:
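The call producing the output below is a one-liner; a minimal sketch (the file name sst189101.grb is an assumption – use any of the downloaded files):

```r
library(stars)

# read a single month's GRIB file and print the stars object summary
grb_file <- file.path("grb", "sst189101.grb")
read_stars(grb_file)
```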

stars object with 2 dimensions and 1 attribute
attribute(s):
 Min.   :-274.9  
 1st Qu.:-272.8  
 Median :-259.1  
 Mean   :-260.0  
 3rd Qu.:-248.4  
 Max.   :-242.8  
 NA's   :21001   
dimension(s):
  from  to offset delta                       refsys point values    
x    1 360      0     1 Coordinate System importe...    NA   NULL [x]
y    1 180     90    -1 Coordinate System importe...    NA   NULL [y]

So in this GRIB file, we have one attribute – which we know to be sea surface temperature – on a two-dimensional grid. As to the latter, we can complement what stars tells us with more information found in the documentation:

The east-west grid points run eastward from 0.5ºE to 0.5ºW, while the north-south grid points run northward from 89.5ºS to 89.5ºN.

We note a few things we'll want to do with this data. For one, the temperatures seem to be given in Kelvin, but with minus signs. We'll remove the minus signs and convert to degrees Celsius for convenience. We'll also have to think about what to do with the NAs that appear for all non-maritime coordinates.

Before we get there though, we need to combine data from all files into a single data frame. This adds an additional dimension, time, ranging from 1891-01-01 to 2020-12-01:

grb <- read_stars(
  file.path(grb_dir, map(readLines("files", warn = FALSE), basename)), along = "time") %>%
  st_set_dimensions(3,
                    values = seq(as.Date("1891-01-01"), as.Date("2020-12-01"), by = "months"),
                    names = "time")

stars object with 3 dimensions and 1 attribute
attribute(s), summary of first 1e+05 cells:
 Min.   :-274.9  
 1st Qu.:-273.3  
 Median :-258.8  
 Mean   :-260.0  
 3rd Qu.:-247.8  
 Max.   :-242.8  
 NA's   :33724   
dimension(s):
     from   to offset delta                       refsys point                    values    
x       1  360      0     1 Coordinate System importe...    NA                      NULL [x]
y       1  180     90    -1 Coordinate System importe...    NA                      NULL [y]
time    1 1560     NA    NA                         Date    NA 1891-01-01,...,2020-12-01    

Let's visually inspect the spatial distribution of monthly temperatures for one year, 2020:

ggplot() +
  geom_stars(data = grb %>% filter(between(time, as.Date("2020-01-01"), as.Date("2020-12-01"))), alpha = 0.8) +
  facet_wrap("time") +
  scale_fill_viridis() +
  coord_equal() +
  theme_map() +
  theme(legend.position = "none")

Figure 1: Monthly sea surface temperatures, 2020-01 – 2020-12.

Target: Niño 3.4 Index

For the Niño 3.4 Index, we download the monthly data and, among the provided features, zoom in on two: the index itself (column NINO34_MEAN) and PHASE, which can be E (El Niño), L (La Niña) or N (neutral).

nino <- read_table2("ONI_NINO34_1854-2020.txt", skip = 9) %>%
  mutate(month = as.Date(paste0(YEAR, "-", `MON/MMM`, "-01"))) %>%
  select(month, NINO34_MEAN, PHASE) %>%
  filter(between(month, as.Date("1891-01-01"), as.Date("2020-08-01"))) %>%
  mutate(phase_code = as.numeric(as.factor(PHASE)))


Next, we look at how to get the data into a format convenient for training and prediction.


First, we remove all input data for points in time where ground truth data are still missing.

Next, as is done by e.g. Ham, Kim, and Luo (2019b), we only use grid points between 55° south and 60° north. This has the additional advantage of reducing memory requirements.

sst <- grb %>% filter(between(y, -55, 60))

360, 115, 1560

As already alluded to, with the little data we have we can't expect much in terms of generalization. Still, we set aside a small portion of the data for validation, since we'd like for this post to serve as a useful template to be used with bigger datasets.
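The split itself is not shown above; mirroring the date-based split applied to the target data further down (training up to the end of 1989, validation from 1990 onward), it could look like this:

```r
# date-based split of the stars object: last thirty-one years go to validation
sst_train <- sst %>% filter(time < as.Date("1990-01-01"))
sst_valid <- sst %>% filter(time >= as.Date("1990-01-01"))
```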

From here on, we work with R arrays.

sst_train <- as.tbl_cube.stars(sst_train)$mets[[1]]
sst_valid <- as.tbl_cube.stars(sst_valid)$mets[[1]]

Conversion to degrees Celsius is not strictly necessary: initial experiments showed a slight performance boost due to normalizing the input, and we're going to do that anyway. Still, Celsius reads nicer to humans than Kelvin.

sst_train <- sst_train + 273.15
quantile(sst_train, na.rm = TRUE)
     0%     25%     50%     75%    100% 
-1.8000 12.9975 21.8775 26.8200 34.3700 

Not at all surprisingly, global warming is evident from inspecting the temperature distribution on the validation set (which was chosen to span the last thirty-one years).

sst_valid <- sst_valid + 273.15
quantile(sst_valid, na.rm = TRUE)
    0%    25%    50%    75%   100% 
-1.800 13.425 22.335 27.240 34.870 

The next-to-last step normalizes both sets according to training mean and variance.

train_mean <- mean(sst_train, na.rm = TRUE)
train_sd <- sd(sst_train, na.rm = TRUE)

sst_train <- (sst_train - train_mean) / train_sd

sst_valid <- (sst_valid - train_mean) / train_sd

Lastly, what should we do about the NA entries? We set them to zero, the (training set) mean. That may not be enough of an action though: It means we're feeding the network roughly 30% misleading data. This is something we're not done with yet.

sst_train[is.na(sst_train)] <- 0
sst_valid[is.na(sst_valid)] <- 0


The target data are split analogously. Let's check though: Are phases (categorizations) distributed equally in both sets?

nino_train <- nino %>% filter(month < as.Date("1990-01-01"))
nino_valid <- nino %>% filter(month >= as.Date("1990-01-01"))

nino_train %>% group_by(phase_code, PHASE) %>% summarise(count = n(), avg = mean(NINO34_MEAN))
# A tibble: 3 x 4
# Groups:   phase_code [3]
  phase_code PHASE count   avg
       <dbl> <chr> <int> <dbl>
1          1 E       301  27.7
2          2 L       333  25.6
3          3 N       554  26.7

nino_valid %>% group_by(phase_code, PHASE) %>% summarise(count = n(), avg = mean(NINO34_MEAN))
# A tibble: 3 x 4
# Groups:   phase_code [3]
  phase_code PHASE count   avg
       <dbl> <chr> <int> <dbl>
1          1 E        93  28.1
2          2 L        93  25.9
3          3 N       182  27.2

This doesn't look too bad. Of course, we again see the overall rise in temperature, irrespective of phase.

Finally, we normalize the index, same as we did for the input data.

train_mean_nino <- mean(nino_train$NINO34_MEAN)
train_sd_nino <- sd(nino_train$NINO34_MEAN)

nino_train <- nino_train %>% mutate(NINO34_MEAN = scale(NINO34_MEAN, center = train_mean_nino, scale = train_sd_nino))
nino_valid <- nino_valid %>% mutate(NINO34_MEAN = scale(NINO34_MEAN, center = train_mean_nino, scale = train_sd_nino))

On to the torch dataset.

The dataset is responsible for correctly matching up inputs and targets.

Our goal is to take six months of global sea surface temperatures and predict the Niño 3.4 Index for the following month. Input-wise, the model will expect the following format semantics:

batch_size * timesteps * width * height * channels, where

  • batch_size is the number of observations worked on in one round of computations,

  • timesteps chains consecutive observations from adjacent months,

  • width and height together constitute the spatial grid, and

  • channels corresponds to the available channels in the "image."

In .getitem(), we select the consecutive observations, starting at a given index, and stack them in dimension one. (One, not two, as batches will only start to exist once the dataloader comes into play.)

Now, what about the target? Our ultimate goal was – is – predicting the Niño 3.4 Index. However, as you see, we define three targets: One is the index, as expected; an additional one holds the spatially-gridded sea surface temperatures for the prediction month. Why? Our main tool, the most prominent constituent of the model, will be a convLSTM, an architecture designed for spatial prediction. Thus, to train it efficiently, we want to give it the opportunity to predict values on a spatial grid. So far so good; but there is one more target, the phase/category. This was added for experimentation purposes: Maybe predicting both index and phase helps in training?

Finally, here is the code for the dataset. In our experiments, we based predictions on inputs from the preceding six months (n_timesteps <- 6). This is a parameter you might want to play with, though.

n_timesteps <- 6

enso_dataset <- dataset(
  name = "enso_dataset",
  initialize = function(sst, nino, n_timesteps) {
    self$sst <- sst
    self$nino <- nino
    self$n_timesteps <- n_timesteps
  },
  .getitem = function(i) {
    x <- torch_tensor(self$sst[, , i:(n_timesteps + i - 1)]) # (360, 115, n_timesteps)
    x <- x$permute(c(3, 1, 2))$unsqueeze(2) # (n_timesteps, 1, 360, 115)
    y1 <- torch_tensor(self$sst[, , n_timesteps + i])$unsqueeze(1) # (1, 360, 115)
    y2 <- torch_tensor(self$nino$NINO34_MEAN[n_timesteps + i])
    y3 <- torch_tensor(self$nino$phase_code[n_timesteps + i])$squeeze()$to(torch_long())
    list(x = x, y1 = y1, y2 = y2, y3 = y3)
  },
  .length = function() {
    nrow(self$nino) - n_timesteps
  }
)

train_ds <- enso_dataset(sst_train, nino_train, n_timesteps)
valid_ds <- enso_dataset(sst_valid, nino_valid, n_timesteps)

After the custom dataset, we create the – quite typical – dataloaders, making use of a batch size of 4.

batch_size <- 4

train_dl <- train_ds %>% dataloader(batch_size = batch_size, shuffle = TRUE)

valid_dl <- valid_ds %>% dataloader(batch_size = batch_size)

Next, we proceed to model creation.

The model's main ingredient is the convLSTM introduced in a prior post. For convenience, we reproduce the code in the appendix.

Besides the convLSTM, the model makes use of three convolutional layers, a batchnorm layer and five linear layers. The logic is the following.

First, the convLSTM's job is to predict the next month's sea surface temperatures on the spatial grid. For that, we almost just return its final state – almost: We use self$conv1 to reduce the number of channels to one.

For predicting index and phase, we then need to flatten the grid, as we require a single value each. This is where the additional conv layers come in. We do hope they'll aid in learning, but we also want to reduce the number of parameters a bit, downsizing the grid (stride = 2 and stride = 3, resp.) before the upcoming torch_flatten().
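As a quick sanity check (not part of the model code), we can verify where the 33,408 input features of self$linear below come from, using the usual output-size formula for an unpadded convolution:

```r
# output size of an unpadded convolution:
# floor((size - kernel_size) / stride) + 1
conv_out <- function(size, kernel_size, stride) {
  (size - kernel_size) %/% stride + 1
}

# the grid coming out of the convLSTM is 360 x 115
h <- conv_out(360, 5, 2) # after conv2: 178
w <- conv_out(115, 5, 2) # after conv2: 56
h <- conv_out(h, 5, 3)   # after conv3: 58
w <- conv_out(w, 5, 3)   # after conv3: 18

32 * h * w # channels * height * width = 33408
```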

Once we have a flat structure, learning is shared between the tasks of index and phase prediction (self$linear), until finally their paths split (self$cont and self$cat, resp.), and they return their separate outputs.

(The batchnorm? I'll comment on that in the Discussion.)

model <- nn_module(
  initialize = function(channels_in,
                        convlstm_hidden,
                        convlstm_kernel,
                        convlstm_layers) {
    self$n_layers <- convlstm_layers
    self$convlstm <- convlstm(
      input_dim = channels_in,
      hidden_dims = convlstm_hidden,
      kernel_sizes = convlstm_kernel,
      n_layers = convlstm_layers
    )
    self$conv1 <- nn_conv2d(
      in_channels = 32,
      out_channels = 1,
      kernel_size = 5,
      padding = 2
    )
    self$conv2 <- nn_conv2d(
      in_channels = 32,
      out_channels = 32,
      kernel_size = 5,
      stride = 2
    )
    self$conv3 <- nn_conv2d(
      in_channels = 32,
      out_channels = 32,
      kernel_size = 5,
      stride = 3
    )
    self$linear <- nn_linear(33408, 64)
    self$b1 <- nn_batch_norm1d(num_features = 64)
    self$cont <- nn_linear(64, 128)
    self$cat <- nn_linear(64, 128)
    self$cont_output <- nn_linear(128, 1)
    self$cat_output <- nn_linear(128, 3)
  },
  forward = function(x) {
    ret <- self$convlstm(x)
    layer_last_states <- ret[[2]]
    last_hidden <- layer_last_states[[self$n_layers]][[1]]
    next_sst <- last_hidden %>% self$conv1()
    c2 <- last_hidden %>% self$conv2()
    c3 <- c2 %>% self$conv3()
    flat <- torch_flatten(c3, start_dim = 2)
    common <- self$linear(flat) %>% self$b1() %>% nnf_relu()

    next_temp <- common %>% self$cont() %>% nnf_relu() %>% self$cont_output()
    next_nino <- common %>% self$cat() %>% nnf_relu() %>% self$cat_output()
    list(next_sst, next_temp, next_nino)
  }
)

Next, we instantiate a rather small-ish model. You're more than welcome to experiment with bigger models, but training time as well as GPU memory requirements will increase.

net <- model(
  channels_in = 1,
  convlstm_hidden = c(16, 16, 32),
  convlstm_kernel = c(3, 3, 5),
  convlstm_layers = 3
)

device <- torch_device(if (cuda_is_available()) "cuda" else "cpu")

net <- net$to(device = device)
net
An `nn_module` containing 2,389,605 parameters.

── Modules ───────────────────────────────────────────────────────────────────────────────
● convlstm: <nn_module> #182,080 parameters
● conv1: <nn_conv2d> #801 parameters
● conv2: <nn_conv2d> #25,632 parameters
● conv3: <nn_conv2d> #25,632 parameters
● linear: <nn_linear> #2,138,176 parameters
● b1: <nn_batch_norm1d> #128 parameters
● cont: <nn_linear> #8,320 parameters
● cat: <nn_linear> #8,320 parameters
● cont_output: <nn_linear> #129 parameters
● cat_output: <nn_linear> #387 parameters

We have three model outputs. How should we combine the losses?

Given that the main goal is predicting the index, and the other two outputs are essentially means to an end, I found the following combination rather effective:

# weight for sea surface temperature prediction
lw_sst <- 0.2

# weight for prediction of the Niño 3.4 Index
lw_temp <- 0.4

# weight for phase prediction
lw_nino <- 0.4

The training process follows the pattern seen in all torch posts so far: For each epoch, loop over the training set, backpropagate, then check performance on the validation set.

But, when we did the pre-processing, we were aware of an imminent problem: the missing temperatures for continental areas, which we set to zero. As a sole measure, this approach is clearly insufficient. What if we had chosen to use latitude-dependent averages? Or interpolation? Both may be better than a global average, but both have their problems as well. Let's at least alleviate negative consequences by not using the respective pixels for spatial loss calculation. This is taken care of by the following line below:

sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])

Here, then, is the complete training code.

optimizer <- optim_adam(net$parameters, lr = 0.001)

num_epochs <- 50

train_batch <- function(b) {
  optimizer$zero_grad()
  output <- net(b$x$to(device = device))
  sst_output <- output[[1]]
  sst_target <- b$y1$to(device = device)
  sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
  temp_loss <- nnf_mse_loss(output[[2]], b$y2$to(device = device))
  nino_loss <- nnf_cross_entropy(output[[3]], b$y3$to(device = device))
  loss <- lw_sst * sst_loss + lw_temp * temp_loss + lw_nino * nino_loss
  loss$backward()
  optimizer$step()

  list(sst_loss$item(), temp_loss$item(), nino_loss$item(), loss$item())
}

valid_batch <- function(b) {
  output <- net(b$x$to(device = device))
  sst_output <- output[[1]]
  sst_target <- b$y1$to(device = device)
  sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
  temp_loss <- nnf_mse_loss(output[[2]], b$y2$to(device = device))
  nino_loss <- nnf_cross_entropy(output[[3]], b$y3$to(device = device))
  loss <- lw_sst * sst_loss + lw_temp * temp_loss + lw_nino * nino_loss

  list(sst_loss$item(), temp_loss$item(), nino_loss$item(), loss$item())
}

for (epoch in 1:num_epochs) {
  net$train()

  train_loss_sst <- c()
  train_loss_temp <- c()
  train_loss_nino <- c()
  train_loss <- c()

  coro::loop(for (b in train_dl) {
    losses <- train_batch(b)
    train_loss_sst <- c(train_loss_sst, losses[[1]])
    train_loss_temp <- c(train_loss_temp, losses[[2]])
    train_loss_nino <- c(train_loss_nino, losses[[3]])
    train_loss <- c(train_loss, losses[[4]])
  })

  cat(sprintf(
    "\nEpoch %d, training: loss: %3.3f sst: %3.3f temp: %3.3f nino: %3.3f \n",
    epoch, mean(train_loss), mean(train_loss_sst), mean(train_loss_temp), mean(train_loss_nino)
  ))

  net$eval()

  valid_loss_sst <- c()
  valid_loss_temp <- c()
  valid_loss_nino <- c()
  valid_loss <- c()

  coro::loop(for (b in valid_dl) {
    losses <- valid_batch(b)
    valid_loss_sst <- c(valid_loss_sst, losses[[1]])
    valid_loss_temp <- c(valid_loss_temp, losses[[2]])
    valid_loss_nino <- c(valid_loss_nino, losses[[3]])
    valid_loss <- c(valid_loss, losses[[4]])
  })

  cat(sprintf(
    "\nEpoch %d, validation: loss: %3.3f sst: %3.3f temp: %3.3f nino: %3.3f \n",
    epoch, mean(valid_loss), mean(valid_loss_sst), mean(valid_loss_temp), mean(valid_loss_nino)
  ))

  torch_save(net, paste0(
    "model_", epoch, "_", round(mean(train_loss), 3), "_", round(mean(valid_loss), 3), ".pt"
  ))
}

When I ran this, performance on the training set decreased in a not-too-fast, but continuous way, while validation set performance kept fluctuating. For reference, total (composite) losses looked like this:

Epoch     Training    Validation
   10        0.336         0.633
   20        0.233         0.295
   30        0.135         0.461
   40        0.099         0.903
   50        0.061         0.727

Thinking of the size of the validation set – thirty-one years, or equivalently, 372 data points – those fluctuations may not be all too surprising.

Now, losses tend to be abstract; let's see what actually gets predicted. We obtain predictions for index values and phases like so …


net$eval()

pred_index <- c()
pred_phase <- c()

coro::loop(for (b in valid_dl) {

  output <- net(b$x$to(device = device))

  pred_index <- c(pred_index, output[[2]]$to(device = "cpu") %>% as.numeric())
  pred_phase <- rbind(pred_phase, as.matrix(output[[3]]$to(device = "cpu")))
})


… and combine these with the ground truth, stripping off the first six rows (six was the number of timesteps used as predictors):

valid_perf <- data.frame(
  actual_temp = nino_valid$NINO34_MEAN[(n_timesteps + 1):nrow(nino_valid)] * train_sd_nino + train_mean_nino,
  actual_nino = factor(nino_valid$phase_code[(n_timesteps + 1):nrow(nino_valid)]),
  pred_temp = pred_index * train_sd_nino + train_mean_nino,
  pred_nino = factor(pred_phase %>% apply(1, which.max))
)

For the phase, we can generate a confusion matrix:

yardstick::conf_mat(valid_perf, actual_nino, pred_nino)
          Truth
Prediction   1   2   3
         1  70   0  43
         2   0  47  10
         3  23  46 123

This looks better than expected (based on the losses). Phases 1 and 2 correspond to El Niño and La Niña, respectively, and these get sharply separated.

What about the Niño 3.4 Index? Let's plot predictions versus ground truth:

valid_perf <- valid_perf %>% 
  select(actual = actual_temp, predicted = pred_temp) %>% 
  add_column(month = seq(as.Date("1990-07-01"), as.Date("2020-08-01"), by = "months")) %>%
  pivot_longer(-month, names_to = "Index", values_to = "temperature")

ggplot(valid_perf, aes(x = month, y = temperature, color = Index)) +
  geom_line() +
  scale_color_manual(values = c("#006D6F", "#B2FFFF"))

Figure 2: Niño 3.4 Index: Ground truth vs. predictions (validation set).

This doesn't look bad either. However, we need to keep in mind that we're predicting just a single time step ahead. We probably shouldn't overestimate the results. Which leads directly to the discussion.

When working with small amounts of data, a lot can be learned by quick-ish experimentation. However, when, at the same time, the task is complex, one should be cautious extrapolating.

For example, well-established regularizers such as batchnorm and dropout, while intended to improve generalization to the validation set, may turn out to severely impede training itself. That is the story behind the single batchnorm layer I kept (I did try having more), and it is also why there is no dropout.

One lesson to learn from this experience, then, is: Make sure the amount of data matches the complexity of the task. That is what we see in the ENSO prediction papers published on arxiv.

If we should treat the results with caution, why even publish the post?

For one, it shows an application of convLSTM to real-world data, employing a rather complex architecture and illustrating techniques like custom losses and loss masking. Similar architectures and strategies should be applicable to a range of real-world tasks – basically, whenever predictors in a time-series problem are given on a spatial grid.

Secondly, the application itself – forecasting an atmospheric phenomenon that greatly affects ecosystems as well as human well-being – seems like an excellent use of deep learning. Applications like these stand out as all the more worthwhile as the same can't be said of everything deep learning is – and will be, barring effective regulation – used for.

Thanks for reading!

A1: List of GRB files

To be put into a text file for use with purrr::walk( … download.file … ).

A2: convlstm code

For an in-depth explanation of convlstm, see the blog post.


convlstm_cell <- nn_module(
  initialize = function(input_dim, hidden_dim, kernel_size, bias) {
    self$hidden_dim <- hidden_dim
    padding <- kernel_size %/% 2
    self$conv <- nn_conv2d(
      in_channels = input_dim + self$hidden_dim,
      # for each of input, forget, output, and cell gates
      out_channels = 4 * self$hidden_dim,
      kernel_size = kernel_size,
      padding = padding,
      bias = bias
    )
  },
  forward = function(x, prev_states) {

    h_prev <- prev_states[[1]]
    c_prev <- prev_states[[2]]
    combined <- torch_cat(list(x, h_prev), dim = 2)  # concatenate along channel axis
    combined_conv <- self$conv(combined)
    gate_convs <- torch_split(combined_conv, self$hidden_dim, dim = 2)
    cc_i <- gate_convs[[1]]
    cc_f <- gate_convs[[2]]
    cc_o <- gate_convs[[3]]
    cc_g <- gate_convs[[4]]
    # input, forget, output, and cell gates (corresponding to torch's LSTM)
    i <- torch_sigmoid(cc_i)
    f <- torch_sigmoid(cc_f)
    o <- torch_sigmoid(cc_o)
    g <- torch_tanh(cc_g)
    # cell state
    c_next <- f * c_prev + i * g
    # hidden state
    h_next <- o * torch_tanh(c_next)
    list(h_next, c_next)
  },
  init_hidden = function(batch_size, height, width) {
    list(torch_zeros(batch_size, self$hidden_dim, height, width, device = self$conv$weight$device),
         torch_zeros(batch_size, self$hidden_dim, height, width, device = self$conv$weight$device))
  }
)

convlstm <- nn_module(
  initialize = function(input_dim, hidden_dims, kernel_sizes, n_layers, bias = TRUE) {
    self$n_layers <- n_layers
    self$cell_list <- nn_module_list()
    for (i in 1:n_layers) {
      cur_input_dim <- if (i == 1) input_dim else hidden_dims[i - 1]
      self$cell_list$append(convlstm_cell(cur_input_dim, hidden_dims[i], kernel_sizes[i], bias))
    }
  },
  # we always assume batch-first
  forward = function(x) {
    batch_size <- x$size()[1]
    seq_len <- x$size()[2]
    height <- x$size()[4]
    width <- x$size()[5]

    # initialize hidden states
    init_hidden <- vector(mode = "list", length = self$n_layers)
    for (i in 1:self$n_layers) {
      init_hidden[[i]] <- self$cell_list[[i]]$init_hidden(batch_size, height, width)
    }

    # list containing the outputs, of length seq_len, for each layer
    # this is the same as h, at each step in the sequence
    layer_output_list <- vector(mode = "list", length = self$n_layers)
    # list containing the last states (h, c) for each layer
    layer_state_list <- vector(mode = "list", length = self$n_layers)

    cur_layer_input <- x
    hidden_states <- init_hidden

    # loop over layers
    for (i in 1:self$n_layers) {
      # every layer's hidden state starts from 0 (non-stateful)
      h_c <- hidden_states[[i]]
      h <- h_c[[1]]
      c <- h_c[[2]]
      # outputs, of length seq_len, for this layer
      # equivalently, list of h states for each time step
      output_sequence <- vector(mode = "list", length = seq_len)

      # loop over timesteps
      for (t in 1:seq_len) {
        h_c <- self$cell_list[[i]](cur_layer_input[ , t, , , ], list(h, c))
        h <- h_c[[1]]
        c <- h_c[[2]]
        # keep track of output (h) for every timestep
        # h has dim (batch_size, hidden_size, height, width)
        output_sequence[[t]] <- h
      }

      # stack hs for all timesteps over seq_len dimension
      # stacked_outputs has dim (batch_size, seq_len, hidden_size, height, width)
      # same as input to forward (x)
      stacked_outputs <- torch_stack(output_sequence, dim = 2)

      # pass the list of outputs (hs) to next layer
      cur_layer_input <- stacked_outputs

      # keep track of list of outputs for this layer
      layer_output_list[[i]] <- stacked_outputs
      # keep track of last state for this layer
      layer_state_list[[i]] <- list(h, c)
    }

    list(layer_output_list, layer_state_list)
  }
)
Ham, Yoo-Geun, Jeong-Hwan Kim, and Jing-Jia Luo. 2019a. "Deep Learning for Multi-Year ENSO Forecasts." Nature 573 (7775): 568–72.
———. 2019b. "Deep Learning for Multi-Year ENSO Forecasts." Nature 573 (7775): 568–72.
