
1.1 Deep sequence model

Definition 1.1 (Informal). We use sequence model to refer to a parameterized map on sequences $y=f_\theta(x)$ where inputs and outputs $x$, $y$ are sequences of length $L$ of feature vectors in $\mathbb{R}^D$, and $\theta$ are parameters to be learned through gradient descent.

Main challenges

General-Purpose Capabilities

  • RNNs: stateful settings that require rapid updating of a hidden state, such as online processing tasks and reinforcement learning.
  • CNNs: modeling uniformly-sampled perceptual signals such as audio, images, and videos.
  • Transformers: modeling dense, complex interactions in domains such as language.
  • NDEs: handling atypical time-series settings such as missing or irregularly-sampled data.

Computational Efficiency

At training time, tasks can generally be formulated with loss functions over entire input sequences, where the central algorithmic concern is how to compute the forward pass efficiently.

At inference time (deploying the model after it has been trained), the setting may change. For example, in online processing or autoregressive generation, inputs are revealed one timestep at a time, and the model must be able to process them efficiently in sequence.

Long-range Dependencies (LRD)

In particular, difficulties can arise from being unable to capture the interactions in the data (e.g., when the model has a finite context window) or from optimization issues (e.g., vanishing gradients when backpropagating through a long computation graph in recurrent models).

1.2 State Space Sequence Model

1.2.1 A General-Purpose Sequence Model

SSMs are simultaneously (1) continuous, (2) recurrent, and (3) convolutional, so a single SSM incorporates the strengths of most existing sequence model architectures (NDEs, RNNs, and CNNs).

1.2.2 Efficient Computation with Structured SSMs (S4)

For an implicit latent state $x(t)\in\mathbb{R}^N$ and a sequence of length $L$, naively computing the full latent state $x$ alone requires $O(N^2L)$ operations and $O(NL)$ space.

Imposing structure on the state matrix $A$ of Eqn. (1.1) yields the so-called structured state space sequence model (S4).

  • The goal is to reduce this cost to $O(N+L)$ time and space.

1.2.3 Addressing Long-Range Dependencies with HiPPO

HiPPO (High-order Polynomial Projection Operators) addresses the long-range dependency issue.

We can view sequence modeling as a special case of online function approximation (or memorization): at each step, the model should maintain a compressed representation of the input history so far. In this view, existing sequence models such as RNNs and Transformers differ mainly in how they formulate the contextual information used to represent the current input.

Building on this polynomial projection framework, orthogonal SSMs with different choices of basis are considered.
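To make the polynomial projection idea concrete, the sketch below builds the HiPPO-LegS (scaled Legendre) state matrix as it is commonly written for S4: $A_{nk} = -\sqrt{2n+1}\sqrt{2k+1}$ for $n > k$, $A_{nn} = -(n+1)$, and $A_{nk} = 0$ for $n < k$, with $B_n = \sqrt{2n+1}$. This is only one member of the family; the helper name `hippo_legs` is hypothetical and NumPy is assumed.

```python
import numpy as np

def hippo_legs(N):
    """Return the (A, B) pair of the HiPPO-LegS SSM with state size N."""
    n = np.arange(N)
    q = np.sqrt(2 * n + 1)
    # Strictly lower-triangular part: -sqrt(2n+1) * sqrt(2k+1) for n > k.
    A = -np.tril(np.outer(q, q), k=-1)
    # Diagonal: -(n + 1); the upper triangle stays zero.
    A -= np.diag(n + 1.0)
    B = q.copy()
    return A, B

A, B = hippo_legs(4)
```

The matrix is lower triangular, so its eigenvalues are simply $-1, -2, \ldots, -N$.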

2.1 Background: The Sequence Model Framework

2.1.1 Learning with Deep Sequence Models

  • An embedding or encoder layer $e_{\text{in}} : \mathbb{R}^{D_\text{in}} → \mathbb{R}^D$ that operates position-wise on the sequence.

    Typical examples are an embedding lookup table for categorical inputs, or a linear projection parameterized by a weight matrix $\theta_\text{in} \in \mathbb{R}^{D_\text{in} \times D}$ for continuous inputs (in this case, $e_{\text{in}}$ is a matrix multiplication $e_{\text{in}}(x) = x \theta_\text{in}$).

  • A neural network block $f_\theta : \mathbb{R}^{L×D} → \mathbb{R}^{L×D}$ built around a core sequence transformation. It is typically repeated $K$ times with independent parameters $f₁ = f_{θ₁}, …, f_K = f_{θ_K}$.
  • A projection or decoder layer $e_\text{out} : \mathbb{R}^D → \mathbb{R}^{D_\text{out}}$ that operates position-wise, often a linear projection.
  • A task-dependent loss function $ℓ : \mathbb{R}^{D_\text{out}} \times \mathbb{R}^{D_\text{out}} → \mathbb{R}$.

The dimensions $D_\text{in}$ and $D_\text{out}$ also depend on the task and are typically determined by the dimensions of the inputs $x ∈ \mathbb{R}^{L×D_\text{in}}$ and outputs $y ∈ \mathbb{R}^{L×D_\text{out}}$.

The overall model is $f_θ = e_\text{out} ∘ f_K ∘ ⋯ ∘ f₁ ∘ e_\text{in} : \mathbb{R}^{L×D_\text{in}} → \mathbb{R}^{L×D_\text{out}}$

with parameters $θ = (θ_\text{in}, θ₁, …, θ_K, θ_\text{out})$. Computing this function is often called computing the forward pass of the model.

Definition 2.4. A causal sequence model $f_\theta$ is one in which $y_k$ depends only on $x_0, x_1, \ldots, x_k$.
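The composition $f_θ = e_\text{out} ∘ f_K ∘ ⋯ ∘ f₁ ∘ e_\text{in}$ can be sketched in NumPy as below. The block built by `make_block` is a hypothetical stand-in for the core sequence transformation (a short causal depthwise convolution, a linear mix, and a residual connection), chosen only so that the whole model is causal in the sense of Definition 2.4. All names and sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_block(D, width=4):
    """Hypothetical core block: causal depthwise convolution + linear mix + residual."""
    kernel = rng.normal(size=(width, D)) / width
    W = rng.normal(size=(D, D)) / np.sqrt(D)

    def block(x):                              # x: (L, D) -> (L, D)
        L = x.shape[0]
        conv = np.zeros_like(x)
        for j in range(min(width, L)):         # conv_k = sum_j kernel_j * x_{k-j}: causal
            conv[j:] += kernel[j] * x[:L - j]
        return x + np.tanh(conv @ W)

    return block

def make_model(D_in, D, D_out, K):
    """f = e_out o f_K o ... o f_1 o e_in, with linear encoder/decoder."""
    theta_in = rng.normal(size=(D_in, D)) / np.sqrt(D_in)
    theta_out = rng.normal(size=(D, D_out)) / np.sqrt(D)
    blocks = [make_block(D) for _ in range(K)]

    def model(x):                              # x: (L, D_in) -> (L, D_out)
        h = x @ theta_in                       # e_in: position-wise linear projection
        for f in blocks:                       # f_1, ..., f_K with independent parameters
            h = f(h)
        return h @ theta_out                   # e_out: position-wise linear projection

    return model

model = make_model(D_in=3, D=8, D_out=2, K=4)
x = rng.normal(size=(10, 3))
y = model(x)
```

Because every block only looks backward in time and the encoder/decoder act position-wise, perturbing $x_5, x_6, \ldots$ leaves $y_0, \ldots, y_4$ unchanged.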

2.2 Background: State Space Models

2.2.1 Linear Time Invariant (LTI) SSMs

In general, the state matrix $A$, input matrix $B$, and output matrix $C$ may be time-varying, i.e., functions of time $t$. Under the linear time-invariant (LTI) assumption, they are constant in time.

  • Note that the equivalence between LTI SSMs and convolutions relies on the LTI assumption: for any given $A$, $B$, and $C$, we can always derive the corresponding convolution kernel $K(t) = Ce^{tA}B$.

2.3 State Space Sequence Models

2.3.1 The Continuous Representation (Discretization)

Real-world data is discrete rather than continuous: we observe an input sequence $u = (u_0, u_1, u_2, \ldots)$ instead of a continuous function $u(t)$.

An additional step size parameter $∆$ is required that represents the resolution of the input. Conceptually, the inputs $u_k$ can be viewed as uniformly-spaced samples from an implicit underlying continuous signal $u(t)$, where $u_k = u(k∆)$. We also call this a timescale.

Analogous to the fact that the SSM has equivalent forms either as a differential equation or a continuous convolution, the discrete-time SSM can be computed either as a recurrence or a discrete convolution.

Euler’s method is the simplest discretization method: it turns $x’(t) = f(x(t))$ into the first-order approximation $x_k = x_{k-1} + \Delta f(x_{k-1})$.

Applying this to the SSM $x'(t) = Ax(t) + Bu(t)$, we have

$\begin{aligned} x_k &= x_{k-1} + \Delta (A x_{k-1} + B u_k) \\ &= (I + \Delta A) x_{k-1} + (\Delta B) u_k \\ &= \overline{A} x_{k-1} + \overline{B} u_k \end{aligned}$

for $\overline{A} := I + \Delta A$ and $\overline{B} := \Delta B$.

Alternative methods include:

  • Bilinear (Tustin)

    $\bar{A} = (I - \tfrac{\Delta}{2} A)^{-1} (I + \tfrac{\Delta}{2} A)$

    $\bar{B} = (I - \tfrac{\Delta}{2} A)^{-1} (\Delta B)$

  • ZOH (Zero-Order Hold)

    $\bar{A} = e^{\Delta A}$

    $\bar{B} = (\Delta A)^{-1} \left(e^{\Delta A} - I\right) \cdot \Delta B$

In short, discretization maps the continuous parameters and the timescale to discrete parameters, $(\Delta, A, B) \mapsto (\bar{A},\bar{B})$, where the mapping can be Euler’s method, the bilinear method, or ZOH.
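The three rules can be written as one function of $(\Delta, A, B)$. A NumPy sketch, assuming a dense, invertible $A$; the helper `expm` is a simple scaling-and-squaring Taylor approximation included only to keep the block self-contained (in practice a library routine such as `scipy.linalg.expm` would be used). For a small step size all three rules agree closely, and for the scalar case $A = -1$, $B = 1$, ZOH gives $\bar{A} = e^{-\Delta}$, $\bar{B} = 1 - e^{-\Delta}$.

```python
import numpy as np

def expm(M, terms=20):
    """Matrix exponential by scaling and squaring a truncated Taylor series."""
    norm = np.linalg.norm(M, 1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    X = M / 2.0**s                       # now ||X||_1 <= 1/2, so the series converges fast
    E = term = np.eye(M.shape[0])
    for k in range(1, terms):
        term = term @ X / k
        E = E + term
    for _ in range(s):                   # undo the scaling: exp(M) = exp(X)^(2^s)
        E = E @ E
    return E

def discretize(A, B, dt, method="zoh"):
    """Map continuous (dt, A, B) to discrete (Abar, Bbar)."""
    I = np.eye(A.shape[0])
    if method == "euler":
        return I + dt * A, dt * B
    if method == "bilinear":
        inv = np.linalg.inv(I - dt / 2 * A)
        return inv @ (I + dt / 2 * A), inv @ (dt * B)
    if method == "zoh":
        Ab = expm(dt * A)
        # Bbar = (dt A)^{-1} (exp(dt A) - I) dt B; requires A to be invertible.
        return Ab, np.linalg.solve(dt * A, (Ab - I) @ (dt * B))
    raise ValueError(f"unknown method: {method}")

A = np.array([[-1.0, 2.0], [0.0, -3.0]])
B = np.array([[1.0], [0.5]])
discrete = {m: discretize(A, B, dt=1e-3, method=m) for m in ("euler", "bilinear", "zoh")}
```

Euler differs from ZOH at order $\Delta^2$, while the bilinear rule differs only at order $\Delta^3$.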

2.3.2 The Recurrent Representation (Efficient Inference)

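As a result, we obtain the discretized SSM, which can be computed as a recurrence: $x_k = \bar{A} x_{k-1} + \bar{B} u_k$, $y_k = C x_k$.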

  • Note the close connection to RNNs: this recurrence has the same form as an RNN hidden-state update, except that it is linear.
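The recurrent view is what makes online inference cheap: each new input updates the state once. A minimal sketch, assuming a single-input single-output SSM given as already-discretized NumPy arrays `Ab`, `Bb`, `C` (hypothetical names), with the state initialized to zero:

```python
import numpy as np

def ssm_recurrent(Ab, Bb, C, u):
    """Run x_k = Ab x_{k-1} + Bb u_k, y_k = C x_k, starting from x_{-1} = 0."""
    x = np.zeros(Ab.shape[0])
    ys = []
    for u_k in u:                 # one O(N^2) step per input for a dense Ab
        x = Ab @ x + Bb * u_k
        ys.append(C @ x)
    return np.array(ys)

# Impulse response of a 1-D example: y_k = C Ab^k Bb = 2 * 0.5^k.
y = ssm_recurrent(np.array([[0.5]]), np.array([1.0]), np.array([2.0]), [1.0, 0.0, 0.0])
```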

2.3.3 The Convolutional Representation (Efficient Training)

  • The kernel $\bar{K} = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \ldots)$, with $y = \bar{K} * u$, is called the State Space Kernel (SSK).
  • Modern CNNs use a finite number of finite-length convolution kernels, whereas the SSK is infinitely long.

  • In practice, the SSK is truncated to the input length $L$.
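Unrolling the recurrence gives $y_k = \sum_{j=0}^{k} C\bar{A}^j\bar{B}\,u_{k-j}$, i.e., a causal convolution with the truncated kernel. The NumPy sketch below (SISO case, hypothetical helper names) computes the kernel naively with $L$ matrix-vector products and checks that the convolutional and recurrent modes agree:

```python
import numpy as np

def ssm_kernel(Ab, Bb, C, L):
    """Truncated SSK: K = (C Bb, C Ab Bb, ..., C Ab^{L-1} Bb)."""
    K, v = [], Bb.copy()
    for _ in range(L):
        K.append(C @ v)
        v = Ab @ v
    return np.array(K)

def ssm_convolutional(Ab, Bb, C, u):
    """y = K * u as a causal convolution: y_k = sum_{j<=k} K_j u_{k-j}."""
    L = len(u)
    return np.convolve(u, ssm_kernel(Ab, Bb, C, L))[:L]

def ssm_recurrent(Ab, Bb, C, u):
    """Reference recurrence x_k = Ab x_{k-1} + Bb u_k, y_k = C x_k."""
    x, ys = np.zeros(Ab.shape[0]), []
    for u_k in u:
        x = Ab @ x + Bb * u_k
        ys.append(C @ x)
    return np.array(ys)

rng = np.random.default_rng(0)
N, L = 4, 32
Ab = 0.9 * np.linalg.qr(rng.normal(size=(N, N)))[0]  # scaled orthogonal: eigenvalues of modulus 0.9
Bb, C = rng.normal(size=N), rng.normal(size=N)
u = rng.normal(size=L)
```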

2.3.4 Summary of SSM Representations

2.3.5 A Note on SSM Dimensions

Essentially, a sequence model is a map $\mathbb{R}^{L\times D}\to \mathbb{R}^{L\times D}$, where $L$ is the sequence length and $D$ is the feature dimension of the input.

In contrast, an SSM (or SSSM) is a map $\mathbb{R}^{L\times M}\to \mathbb{R}^{L\times M}$, where $M$ must divide $D$. This can be viewed as a multi-head structure: each SSM operates on an $M$-dimensional subspace of the $D$-dimensional input, giving a multi-head SSM with $H = D/M$ heads.

  • An RNN corresponds to the maximal MIMO case, $M = D$ (a single head acting on all channels).
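A sketch of the multi-head view, assuming each head is a discrete MIMO SSM with its own $(A_h, B_h, C_h)$ acting on $M$ of the $D$ channels; $M = 1$ gives the SISO case with $H = D$ heads, and $M = D$ gives a single head. The function name and shapes are illustrative assumptions.

```python
import numpy as np

def multihead_ssm(u, As, Bs, Cs):
    """Apply H independent discrete SSMs to an (L, D) input, each on M = D / H channels.

    As: (H, N, N), Bs: (H, N, M), Cs: (H, M, N).
    """
    L, D = u.shape
    H, N, M = Bs.shape
    assert H * M == D, "M must divide D"
    u_heads = u.reshape(L, H, M)          # channel h*M + m belongs to head h
    y = np.zeros((L, H, M))
    for h in range(H):
        x = np.zeros(N)
        for k in range(L):
            x = As[h] @ x + Bs[h] @ u_heads[k, h]
            y[k, h] = Cs[h] @ x
    return y.reshape(L, D)

rng = np.random.default_rng(0)
L, D, H, N = 16, 6, 3, 4
M = D // H
As = 0.5 * rng.normal(size=(H, N, N)) / np.sqrt(N)
Bs = rng.normal(size=(H, N, M))
Cs = rng.normal(size=(H, M, N))
u = rng.normal(size=(L, D))
y = multihead_ssm(u, As, Bs, Cs)
```

Changing only the channels of head 0 leaves the outputs of the other heads untouched, which is exactly the independence that distinguishes this from a single $D$-dimensional MIMO recurrence.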

3.1 Motivation: Computational Difficulty of SSMs

3.1.1 Discussion: General Recurrent Computation

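  • The problem is that a naive SSM cannot achieve the optimal cost of $O(N)$ per step. Several components cause inefficient computation: with a general (dense) state matrix, each recurrence step requires an $O(N^2)$ matrix-vector multiplication, and the discretization itself (e.g., the matrix inverse in the bilinear method) adds further cost.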
  • That’s why we need a structured formulation that allows efficient matrix-vector multiplications (MVMs).

3.1.2 Discussion: General Convolutional Computation

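  • In convolution mode, naively computing the SSK requires $O(N^2L)$ operations ($L$ successive MVMs with $\bar{A}$) and $O(NL)$ space.
  • Moreover, the repeated MVMs (powers $\bar{A}^k$) raise numerical stability issues, since they can blow up or vanish.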

3.1.3 Structured SSMs

Imposing structure on the matrices addresses both the time/space complexity and numerical stability; this is how the Structured State Space Sequence (S4) model handles these problems.

3.2 Diagonally Structured State Space Models

  • The key idea is to diagonalize the state matrix. To begin with, changing the basis of the state via $x = V\tilde{x}$ yields an equivalent SSM $(V^{-1}AV, V^{-1}B, CV)$ that computes exactly the same input-output map.
  • In other words, among all equivalent parameterizations we can pick one whose matrices carry additional structure. Choosing them properly (e.g., a diagonal state matrix) gives an efficient SSM and solves the computational problem.
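The change-of-basis argument can be checked numerically. This NumPy sketch (hypothetical names) diagonalizes a random, generically diagonalizable state matrix and confirms that the conjugated SSM has the same kernel $C A^t B$. It is written in the discrete setting, but the same conjugation identity applies to the continuous parameters.

```python
import numpy as np

def kernel(A, B, C, L):
    """K_t = C A^t B for t = 0..L-1 (complex arithmetic so V may be complex)."""
    K, v = [], B.astype(complex)
    for _ in range(L):
        K.append(C @ v)
        v = A @ v
    return np.array(K)

rng = np.random.default_rng(0)
N, L = 4, 20
A = 0.3 * rng.normal(size=(N, N))
B, C = rng.normal(size=N), rng.normal(size=N)

# Diagonalize A = V diag(w) V^{-1}; (V^{-1} A V, V^{-1} B, C V) is an equivalent SSM.
w, V = np.linalg.eig(A)
Vinv = np.linalg.inv(V)
A_diag, B_diag, C_diag = Vinv @ A @ V, Vinv @ B, C @ V
```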

3.2.1 S4D: Diagonal SSM Algorithms

Recurrence mode

  • Simple to compute: with a diagonal $\bar{A}$, every step is element-wise, costing $O(N)$ per step.

Convolution mode

Time and Space Complexity

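  • The kernel reduces to a Vandermonde matrix-vector product, $\bar{K}_\ell = \sum_n C_n \bar{B}_n \bar{\lambda}_n^{\ell}$, where $\bar{\lambda}_n$ are the diagonal entries of $\bar{A}$. This product can be computed in $O(L+N)$ operations with $O(L+N)$ space (a naive implementation that materializes the $N \times L$ Vandermonde matrix takes $O(NL)$).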
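With a diagonal $\bar{A} = \operatorname{diag}(\bar{\lambda})$, both modes become simple. This NumPy sketch (hypothetical names, randomly drawn stable eigenvalues) computes the kernel as a Vandermonde product and checks it against the element-wise recurrence. It materializes the full $N \times L$ matrix, so it illustrates the naive $O(NL)$ version rather than a fast algorithm.

```python
import numpy as np

def kernel_vandermonde(lam_bar, B_bar, C, L):
    """K_l = sum_n C_n B_n lam_n^l as a Vandermonde matrix-vector product."""
    V = lam_bar[:, None] ** np.arange(L)[None, :]   # V[n, l] = lam_n^l, shape (N, L)
    return (C * B_bar) @ V

def kernel_recurrent(lam_bar, B_bar, C, L):
    """Reference: K_l = C diag(lam)^l B, stepping the state element-wise."""
    K, v = [], B_bar.astype(complex)
    for _ in range(L):
        K.append(C @ v)
        v = lam_bar * v                              # diagonal Abar: O(N) per step
    return np.array(K)

rng = np.random.default_rng(0)
N, L = 8, 64
lam_bar = 0.95 * np.exp(1j * rng.uniform(0, np.pi, N))  # |lam| < 1: a stable diagonal Abar
B_bar = rng.normal(size=N) + 1j * rng.normal(size=N)
C = rng.normal(size=N) + 1j * rng.normal(size=N)
```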

3.2.2 Complete Implementation Example
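A minimal single-channel S4D forward pass can be sketched in NumPy as follows, under several assumptions that are not taken from the text above: S4D-Lin initialization $A_n = -\tfrac12 + i\pi n$; ZOH discretization; $B$ fixed to ones and absorbed into $C$; only one eigenvalue of each conjugate pair stored (hence the factor 2 and the real part); and an FFT convolution zero-padded to length $2L$. All function names are hypothetical.

```python
import numpy as np

def s4d_lin_init(N):
    """S4D-Lin eigenvalues A_n = -1/2 + i*pi*n, one per conjugate pair (N/2 values)."""
    return -0.5 + 1j * np.pi * np.arange(N // 2)

def s4d_kernel(A, C, log_dt, L):
    """Length-L real kernel for diagonal A with ZOH (B = 1 absorbed into C)."""
    dtA = np.exp(log_dt) * A
    C_bar = C * (np.exp(dtA) - 1.0) / A            # C_n * Bbar_n, Bbar_n = (exp(dt A_n) - 1) / A_n
    vander = np.exp(dtA[:, None] * np.arange(L))   # Abar_n^l = exp(dt A_n l), shape (N/2, L)
    return 2.0 * (C_bar @ vander).real             # adding the conjugate half gives a real kernel

def s4d_forward(u, A, C, D, log_dt):
    """y = K * u + D u, with the causal convolution computed via FFT."""
    L = len(u)
    K = s4d_kernel(A, C, log_dt, L)
    n = 2 * L                                      # zero-pad so circular conv equals linear conv
    y = np.fft.irfft(np.fft.rfft(u, n) * np.fft.rfft(K, n), n)[:L]
    return y + D * u

rng = np.random.default_rng(0)
N, L = 16, 100
A = s4d_lin_init(N)
C = rng.normal(size=N // 2) + 1j * rng.normal(size=N // 2)
D, log_dt = 0.5, np.log(0.01)
u = rng.normal(size=L)
y = s4d_forward(u, A, C, D, log_dt)
```

In a trainable layer, `A` (real and imaginary parts), `C`, `D`, and `log_dt` would be parameters, and the layer would be applied per channel; the FFT path gives $O(L \log L)$ cost for the convolution once the kernel is known.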

3.3 The Diagonal Plus Low-Rank (DPLR) Parameterization

This section introduces diagonal plus low-rank (DPLR) SSMs, which can also be computed efficiently.

3.3.1 Overview of the DPLR State Space Kernel Algorithm

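  • Generating function: replaces the matrix powers in the SSK truncated to length $L$ with a matrix inverse.
  • Woodbury identity: reduces the inverse of a DPLR matrix to inverses of diagonal matrices.
  • Cauchy kernel: computes the resulting diagonal-inverse terms efficiently.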

3.3.2 S4-DPLR Algorithms and Computational Complexity

(Informal)

Not every SSM state matrix can be diagonalized in a numerically stable way. For this reason, the diagonal plus low-rank (DPLR) structure is introduced as follows:

$A = \Lambda + PQ^*$ s.t. $\operatorname{rank}(PQ^*) \leq r \ll N$, where

  • $A \in \mathbb{C}^{N\times N}$
  • $\Lambda \in \mathbb{C}^{N\times N}$ is diagonal
  • $P,Q \in \mathbb{C}^{N\times r}$
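The benefit of the DPLR structure is already visible in a single matrix-vector product, which costs $O(Nr)$ instead of $O(N^2)$. A NumPy sketch with hypothetical names, storing $\Lambda$ as the vector `lam` of its diagonal entries:

```python
import numpy as np

def dplr_matvec(lam, P, Q, x):
    """(diag(lam) + P Q^*) x without forming the N x N matrix: O(N r)."""
    return lam * x + P @ (Q.conj().T @ x)

rng = np.random.default_rng(0)
N, r = 64, 2
lam = rng.normal(size=N) + 1j * rng.normal(size=N)
P = rng.normal(size=(N, r)) + 1j * rng.normal(size=(N, r))
Q = rng.normal(size=(N, r)) + 1j * rng.normal(size=(N, r))
x = rng.normal(size=N) + 1j * rng.normal(size=N)
A_dense = np.diag(lam) + P @ Q.conj().T   # dense version, only for comparison
```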

3.3.3 Hurwitz (Stable) DPLR Form

Note that the DPLR structure enables efficient computation but does not by itself guarantee numerical stability. Restricting to Hurwitz matrices addresses this issue.

Definition 3.7. A Hurwitz matrix A is one where every eigenvalue has negative real part.

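  • Intuitively, if $A$ is Hurwitz, then $e^{tA} \to 0$ as $t \to \infty$, so the kernel $K(t) = Ce^{At}B$ decays instead of blowing up. Correspondingly, after discretization with ZOH or the bilinear method, every eigenvalue of $\bar{A}$ lies inside the unit disk, so repeatedly multiplying by $\bar{A}$ remains numerically stable.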
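For a diagonal Hurwitz $A$ this can be checked directly: under ZOH and the bilinear method every discrete eigenvalue lies strictly inside the unit disk, while Euler’s method gives no such guarantee for a large step size. A NumPy sketch with made-up eigenvalues and a hypothetical helper name:

```python
import numpy as np

lam = np.array([-0.5 + 3j, -1.0 + 0j, -2.0 + 1j])  # Hurwitz: every Re(lam) < 0

def discrete_eigs(lam, dt, method):
    """Eigenvalues of Abar for a diagonal A under each discretization rule."""
    if method == "zoh":
        return np.exp(dt * lam)
    if method == "bilinear":
        return (1 + dt / 2 * lam) / (1 - dt / 2 * lam)
    if method == "euler":
        return 1 + dt * lam
    raise ValueError(method)

dt = 1.0
radius = {m: np.abs(discrete_eigs(lam, dt, m)).max() for m in ("zoh", "bilinear", "euler")}
```

With a spectral radius below 1, powers $\bar{A}^k$ decay geometrically; with Euler at this step size, they grow.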

However, ensuring that a general DPLR matrix $\Lambda + PQ^*$ remains Hurwitz (e.g., throughout training) is difficult.

(Informal)

To ensure numerical stability, we use the form $A := \Lambda - PP^*$ with $\operatorname{Re}(\Lambda) < 0$, which is always Hurwitz and still easy to compute with.
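As an example, S4 shows that the HiPPO-LegS matrix admits this kind of decomposition: adding the rank-1 term $PP^\top$ with $P_n = \sqrt{n + 1/2}$ gives a normal matrix $-\tfrac12 I + S$ with $S$ skew-symmetric, which is unitarily diagonalizable with eigenvalues of real part $-\tfrac12$. Conjugating by the unitary eigenvector matrix $V$ then yields $A = V(\Lambda - \tilde{P}\tilde{P}^*)V^*$ with $\operatorname{Re}(\Lambda) < 0$. The NumPy sketch below checks this numerically (variable names are hypothetical):

```python
import numpy as np

N = 8
n = np.arange(N)
q = np.sqrt(2 * n + 1)
A = -np.tril(np.outer(q, q), k=-1) - np.diag(n + 1.0)   # HiPPO-LegS state matrix
P = np.sqrt(n + 0.5)                                      # rank-1 correction

S = A + np.outer(P, P)             # normal part: -1/2 I + skew-symmetric
S_skew = S + 0.5 * np.eye(N)
# -i * S_skew is Hermitian, so eigh gives unitary V and real w_im with S_skew = V diag(i w_im) V^*.
w_im, V = np.linalg.eigh(-1j * S_skew)
Lam = -0.5 + 1j * w_im             # eigenvalues of the normal part
P_t = V.conj().T @ P               # P expressed in the eigenbasis

# Hurwitz DPLR form in the rotated basis: Lambda - P_t P_t^*
A_rebuilt = V @ (np.diag(Lam) - np.outer(P_t, P_t.conj())) @ V.conj().T
```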

(Informal) Remarks

  • With the DPLR form $A = \Lambda - PP^*$, the state matrix is no longer diagonal, so we cannot use the Vandermonde structure that is used to compute the SSK for diagonal SSMs.
  • That’s why a different algorithm is needed to compute the kernel in DPLR form.

  • Specifically, we combine the kernel generating function, the Woodbury identity, and the Cauchy kernel.

Point 1. Kernel Generating Function

  • In practice, we need the kernel truncated to the input length $L$.
  • Naively, this means computing all $L$ kernel entries, each involving a power of $\bar{A}$, to implement the convolution.

The kernel generating function produces the same result without explicitly computing each kernel entry $K_t$.

  • Details

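    In the time domain, naively computing the kernel for input length $L$ requires $O(N^2L)$ operations.

    In the frequency domain, however, the convolution becomes a pointwise product (with suitable zero-padding):

    $\operatorname{FFT}(y) = \operatorname{FFT}(u*\bar{K}) = \operatorname{FFT}(u)\cdot\operatorname{FFT}(\bar{K})$,

    which requires $O(L\log L)$ operations once the kernel is known.

    The frequency-domain view of the kernel is captured by its generating function in a complex variable $z$:

    $G(z):= \sum_{t=0}^\infty \bar{K}_t z^t = \sum_{t=0}^\infty (C\bar{A}^t\bar{B}) z^t$

    It is notable that $G(z)$ can be evaluated without explicitly computing each $\bar{K}_t$, because the geometric series has a closed form (when the spectral radius of $z\bar{A}$ is below 1):

    $G(z)= C\left(\sum_{t=0}^\infty (z\bar{A})^t\right)\bar{B} = C(I-z\bar{A})^{-1}\bar{B}$

    For the kernel truncated to length $L$ and an $L$-th root of unity $z$ ($z^L = 1$), the same argument gives $\sum_{t=0}^{L-1}\bar{K}_t z^t = C(I-\bar{A}^L)(I-z\bar{A})^{-1}\bar{B}$. Evaluating this at the $L$ roots of unity yields the DFT of the truncated kernel, and one inverse FFT recovers $\bar{K}$.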


With the kernel generating function, we no longer need to compute matrix powers for every kernel entry; instead, we only need a single matrix inverse per evaluation point $z$.
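This can be verified numerically. Truncating the series to $L$ terms and using $z^L = 1$ turns the sum into $C(I-\bar{A}^L)(I-z\bar{A})^{-1}\bar{B}$. The NumPy sketch below (hypothetical names, a dense $\bar{A}$ for simplicity) evaluates that expression at the roots of unity $z_k = e^{-2\pi i k/L}$, matching NumPy’s FFT sign convention, and recovers the truncated kernel with one inverse FFT. Each evaluation is still a dense linear solve here; the Woodbury identity and Cauchy kernel are what make it cheap.

```python
import numpy as np

def kernel_by_powers(Ab, Bb, C, L):
    """Truncated SSK K_t = C Abar^t Bbar, t = 0..L-1, via L matrix-vector products."""
    K, v = [], Bb.copy()
    for _ in range(L):
        K.append(C @ v)
        v = Ab @ v
    return np.array(K)

def kernel_by_generating_function(Ab, Bb, C, L):
    """Same kernel from G(z) = C (I - Abar^L) (I - z Abar)^{-1} Bbar at the L-th roots of unity."""
    N = Ab.shape[0]
    I = np.eye(N)
    z = np.exp(-2j * np.pi * np.arange(L) / L)          # z_k = exp(-2 pi i k / L)
    C_tilde = C @ (I - np.linalg.matrix_power(Ab, L))    # absorbs the truncation, since z^L = 1
    G = np.array([C_tilde @ np.linalg.solve(I - zk * Ab, Bb) for zk in z])
    return np.fft.ifft(G).real                          # G[k] = sum_t K_t z_k^t = FFT(K)[k]

rng = np.random.default_rng(0)
N, L = 4, 16
Ab = 0.5 * rng.normal(size=(N, N)) / np.sqrt(N)
Bb, C = rng.normal(size=N), rng.normal(size=N)
```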

Point 2. Woodbury Identity

So far, the kernel generating function lets us compute the output without computing the kernel entries explicitly. However, recall that our state matrix is given in Hurwitz DPLR form, $A=\Lambda - PP^*$, so we still need to invert a matrix that depends on this DPLR matrix $A$ (e.g., $I - zA$).


  • Here $R(z;\Lambda) = (zI - \Lambda)^{-1}$ is the resolvent of $\Lambda$, which is easy to compute because $\Lambda$ is diagonal: its entries are just $1/(z - \lambda_n)$.

By using the Woodbury identity, we can compute the kernel generating function when the matrix is given in Hurwitz DPLR form. However, evaluating $\hat{K}(z)$ this way at all $L$ points still requires $O(NL)$ operations.
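For $A = \Lambda - PP^*$, write $zI - A = R^{-1} + PP^*$ with the diagonal resolvent $R = (zI - \Lambda)^{-1}$. The Woodbury identity then gives $(zI - A)^{-1} = R - RP(I_r + P^*RP)^{-1}P^*R$, so only an $r \times r$ system has to be solved. A NumPy sketch with hypothetical names; it forms the dense $N \times N$ result only to compare against `np.linalg.inv`, whereas an efficient implementation would only apply it to vectors.

```python
import numpy as np

def dplr_resolvent(z, lam, P):
    """(z I - A)^{-1} for A = diag(lam) - P P^*, using the Woodbury identity."""
    R = 1.0 / (z - lam)                       # diagonal resolvent R(z; Lambda): O(N)
    RP = R[:, None] * P                       # R P, shape (N, r)
    PhR = P.conj().T * R[None, :]             # P^* R, shape (r, N)
    inner = np.eye(P.shape[1]) + PhR @ P      # I_r + P^* R P: only r x r
    # R - R P (I + P^* R P)^{-1} P^* R, formed densely here only for checking
    return np.diag(R) - RP @ np.linalg.solve(inner, PhR)

rng = np.random.default_rng(0)
N, r = 16, 1
lam = -rng.uniform(0.1, 1.0, N) + 1j * rng.normal(size=N)   # Re(lam) < 0
P = rng.normal(size=(N, r)) + 1j * rng.normal(size=(N, r))
z = 0.3 + 2.0j
A = np.diag(lam) - P @ P.conj().T
```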

Point 3. Cauchy Matrix

  • This step rests on the observation that every term in the Woodbury-based formula is a Cauchy matrix-vector product: for vectors $u, v$, $u^*R(z;\Lambda)v = \sum_n \frac{u_n^* v_n}{z - \lambda_n}$, i.e., a product with the Cauchy matrix $M_{jn} = \frac{1}{z_j - \lambda_n}$ over the evaluation points $z_j$.
  • That is, by interpreting the kernel generating function of a Hurwitz DPLR matrix in terms of Cauchy matrices, we can implement it efficiently.

Now we can compute the kernel generating function for a Hurwitz DPLR matrix with $O((M+N)\log (M+N))$ complexity, where $M$ is the number of evaluation points (here $M = L$).
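Putting the three points together, the sketch below computes the truncated kernel of a discrete SSM with state matrix $\bar{A} = \Lambda - pp^*$ (rank 1) using only Cauchy matrix-vector products $\sum_n v_n / (w_j - \lambda_n)$: the generating function at the roots of unity, the Woodbury identity for the rank-1 correction, and one inverse FFT. It is a simplification under stated assumptions, not the S4 implementation: it works directly with discrete parameters (S4 discretizes a continuous DPLR matrix with the bilinear method), precomputes $\tilde{C} = C(I - \bar{A}^L)$ densely, and uses the naive $O(NL)$ Cauchy product rather than a fast algorithm. All names are hypothetical.

```python
import numpy as np

def cauchy(v, w, lam):
    """Naive Cauchy matrix-vector product: out[j] = sum_n v[n] / (w[j] - lam[n])."""
    return (v[None, :] / (w[:, None] - lam[None, :])).sum(axis=1)

def dplr_kernel(lam, p, b, c_tilde, L):
    """Truncated kernel of Abar = diag(lam) - p p^* from Cauchy products only.

    c_tilde = c (I - Abar^L) is assumed to be given.
    """
    z = np.exp(-2j * np.pi * np.arange(L) / L)   # L-th roots of unity (numpy FFT convention)
    w = 1.0 / z                                  # (I - z Abar)^{-1} = w (w I - Abar)^{-1}
    # Woodbury with R = (w I - diag(lam))^{-1}:
    # c~ (w I - Abar)^{-1} b = c~Rb - (c~Rp)(p^*Rb) / (1 + p^*Rp)
    k00 = cauchy(c_tilde * b, w, lam)
    k01 = cauchy(c_tilde * p, w, lam)
    k10 = cauchy(p.conj() * b, w, lam)
    k11 = cauchy(p.conj() * p, w, lam)
    G = w * (k00 - k01 * k10 / (1.0 + k11))      # generating function at each z_k
    return np.fft.ifft(G)                        # G[k] = FFT(K)[k]

rng = np.random.default_rng(0)
N, L = 6, 32
lam = 0.6 * rng.uniform(0.3, 1.0, N) * np.exp(1j * rng.uniform(0, 2 * np.pi, N))
p = 0.1 * (rng.normal(size=N) + 1j * rng.normal(size=N))
b = rng.normal(size=N) + 1j * rng.normal(size=N)
c = rng.normal(size=N) + 1j * rng.normal(size=N)

Abar = np.diag(lam) - np.outer(p, p.conj())
c_tilde = c @ (np.eye(N) - np.linalg.matrix_power(Abar, L))
K_ref = np.array([c @ np.linalg.matrix_power(Abar, t) @ b for t in range(L)])
K = dplr_kernel(lam, p, b, c_tilde, L)
```

Replacing `cauchy` with a fast Cauchy algorithm is the only change needed to reach the near-linear complexity stated above; everything else is $O(N + L)$ work plus one FFT.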

3.4 Additional Parameterization Details

3.4.1 Discretization

Recall that discretization can be implemented via (1) Euler’s method, (2) ZOH, or (3) the bilinear method. Each must then be checked for compatibility with the DPLR parameterization.

3.4.2 Parameterization of A

Recall that $A$ should be Hurwitz, i.e., the real parts of its (diagonal) entries should be negative. To enforce this, we parameterize $A=-\exp(A_\text{Re}) + i \cdot A_\text{Im}$. The exponential function is not essential; other functions, such as ReLU, can be used instead.
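A sketch of this parameterization for the diagonal of $A$, with `A_re` and `A_im` as unconstrained real vectors (hypothetical names; NumPy assumed). With the exponential, every real part is strictly negative; with ReLU, the real part is only guaranteed to be non-positive, so some entries can land on the imaginary axis.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def hurwitz_diag(A_re, A_im, positive=np.exp):
    """Diagonal of A as -positive(A_re) + i * A_im, from unconstrained real vectors."""
    return -positive(A_re) + 1j * A_im

rng = np.random.default_rng(0)
A_re, A_im = rng.normal(size=64), rng.normal(size=64)
A_exp = hurwitz_diag(A_re, A_im)           # Re < 0 for every entry, whatever A_re is
A_relu = hurwitz_diag(A_re, A_im, relu)    # Re <= 0; entries with A_re <= 0 have Re = 0
```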

In addition, $A$ should also be expressed in DPLR form.
