Abstractions of FSTs with Monads

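Finite state transducers (FSTs) are usually formulated as a tuple \((S, A, B, S_0, S_f, \delta)\) comprising the following components:

- a finite set of states \(S\);
- an input alphabet \(A\);
- an output alphabet \(B\);
- a set of initial states \(S_0 \subseteq S\);
- a set of final states \(S_f \subseteq S\);
- a transition function \(\delta: S \times A \to 2^{S \times B}\), mapping a state and an input symbol to the set of possible (next state, output symbol) pairs.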

We encode this in Scala:

trait Transducer[-A, B] {  // B is invariant: Scala's Set is invariant, so +B would not compile
  type S
  def initialStates: Set[S]
  def finalStates: Set[S]
  def next(s: S, a: A): Set[(S, B)]
}

Here we encode the state type as an abstract type member (effectively an existential type): states are internal and of no interest to the users of FSTs. Note that the output of the next function is a Set[(S, B)]: it encodes multiple possible transition outcomes given a state \(s\) and an input \(a\).
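To make the role of Set[(S, B)] concrete, here is a minimal sketch of running such a transducer over an input sequence by tracking every path at once. The helper run and the toy transducer ambiguous are hypothetical illustrations, not part of the original design; the trait is repeated (with B left invariant, since Scala's Set is invariant) so the snippet compiles on its own.

```scala
object FstRun {
  // The Transducer trait from the text; B is invariant because Scala's Set is invariant.
  trait Transducer[-A, B] {
    type S
    def initialStates: Set[S]
    def finalStates: Set[S]
    def next(s: S, a: A): Set[(S, B)]
  }

  // Hypothetical helper: every output sequence the transducer can emit on `input`,
  // over all paths that start in an initial state and end in a final state.
  def run[A, B](t: Transducer[A, B])(input: Seq[A]): Set[List[B]] = {
    // Each path is (current state, outputs emitted so far, most recent first).
    val start: Set[(t.S, List[B])] = t.initialStates.map(s => (s, List.empty[B]))
    val ends = input.foldLeft(start) { (paths, a) =>
      for {
        (s, out) <- paths
        (s2, b) <- t.next(s, a)
      } yield (s2, b :: out)
    }
    ends.collect { case (s, out) if t.finalStates.contains(s) => out.reverse }
  }

  // Hypothetical toy transducer: each 'a' may become 'x' or 'y'; other characters are copied.
  // It has a single state that is both initial and final.
  val ambiguous: Transducer[Char, Char] = new Transducer[Char, Char] {
    type S = Unit
    def initialStates: Set[Unit] = Set(())
    def finalStates: Set[Unit] = Set(())
    def next(s: Unit, c: Char): Set[(Unit, Char)] =
      if (c == 'a') Set(((), 'x'), ((), 'y')) else Set(((), c))
  }

  def main(args: Array[String]): Unit =
    println(run(ambiguous)("ab".toList)) // two outputs: List(x, b) and List(y, b)
}
```

Because next returns a set, one input symbol can fan out into several paths; run simply carries all of them forward and keeps those that end in a final state.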

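One of the most interesting computational aspects of FSTs is that they can be composed (unlike RNNs). Mathematically, given FSTs \(T_1\) from \(A\) to \(B\) and \(T_2\) from \(B\) to \(C\), we can compose them into one FST \(T_2 \circ T_1\) from \(A\) to \(C\), where a pair of strings \((a, c)\) is accepted if and only if there exists a string \(b\) such that \(T_1\) accepts \((a, b)\) and \(T_2\) accepts \((b, c)\).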

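From a functional programming perspective, this means FSTs form a category: alphabets are the objects, transducers are the arrows, and the identity arrow is a one-state transducer that copies its input to its output.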

// Category: a type class with id and compose (e.g. cats.arrow.Category)
implicit object TransducerCategory extends Category[Transducer] {

  def id[A]: Transducer[A, A] = 
    new Transducer[A, A] {
      type S = Unit
      def initialStates = Set(())
      def finalStates = Set(())
      def next(s: S, a: A): Set[(S, A)] = Set((s, a))
    }
  
  def compose[A, B, C](
                       t2: Transducer[B, C], 
                       t1: Transducer[A, B]
                      ): Transducer[A, C] =
    new Transducer[A, C] {
      type S = (t1.S, t2.S)
      def initialStates = for {
        s1 <- t1.initialStates
        s2 <- t2.initialStates
      } yield (s1, s2)
      def finalStates = for {
        s1 <- t1.finalStates
        s2 <- t2.finalStates
      } yield (s1, s2)
      def next(s: S, a: A): Set[(S, C)] = {
        val (s1, s2) = s
        for {
          (s1_, b) <- t1.next(s1, a)
          (s2_, c) <- t2.next(s2, b)
        } yield ((s1_, s2_), c)
      }
    }
}
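The composition can be exercised end to end. The sketch below repeats the trait and the corrected category instance so it compiles on its own, assuming a minimal Category trait with the same id/compose shape as cats.arrow.Category; the transducers ambiguous and evenX and the helper run are hypothetical. Composing evenX after ambiguous keeps only those rewrites of the input that contain an even number of 'x'.

```scala
object FstCompose {
  // Minimal Category type class, shaped like cats.arrow.Category (assumed here).
  trait Category[F[_, _]] {
    def id[A]: F[A, A]
    def compose[A, B, C](f: F[B, C], g: F[A, B]): F[A, C]
  }

  trait Transducer[-A, B] {
    type S
    def initialStates: Set[S]
    def finalStates: Set[S]
    def next(s: S, a: A): Set[(S, B)]
  }

  implicit object TransducerCategory extends Category[Transducer] {
    // Identity: one state, copies each input symbol to the output.
    def id[A]: Transducer[A, A] = new Transducer[A, A] {
      type S = Unit
      def initialStates: Set[Unit] = Set(())
      def finalStates: Set[Unit] = Set(())
      def next(s: Unit, a: A): Set[(Unit, A)] = Set(((), a))
    }
    // Product construction: step t1, then feed each of its outputs into t2.
    def compose[A, B, C](t2: Transducer[B, C], t1: Transducer[A, B]): Transducer[A, C] =
      new Transducer[A, C] {
        type S = (t1.S, t2.S)
        def initialStates: Set[S] =
          for { s1 <- t1.initialStates; s2 <- t2.initialStates } yield (s1, s2)
        def finalStates: Set[S] =
          for { s1 <- t1.finalStates; s2 <- t2.finalStates } yield (s1, s2)
        def next(s: S, a: A): Set[(S, C)] =
          for {
            (s1n, b) <- t1.next(s._1, a)
            (s2n, c) <- t2.next(s._2, b)
          } yield ((s1n, s2n), c)
      }
  }

  // Hypothetical path-tracking runner: all outputs over accepting paths.
  def run[A, B](t: Transducer[A, B])(input: Seq[A]): Set[List[B]] = {
    val start: Set[(t.S, List[B])] = t.initialStates.map(s => (s, List.empty[B]))
    val ends = input.foldLeft(start) { (paths, a) =>
      for { (s, out) <- paths; (s2, b) <- t.next(s, a) } yield (s2, b :: out)
    }
    ends.collect { case (s, out) if t.finalStates.contains(s) => out.reverse }
  }

  // Each 'a' may become 'x' or 'y'; other characters are copied.
  val ambiguous: Transducer[Char, Char] = new Transducer[Char, Char] {
    type S = Unit
    def initialStates: Set[Unit] = Set(())
    def finalStates: Set[Unit] = Set(())
    def next(s: Unit, c: Char): Set[(Unit, Char)] =
      if (c == 'a') Set(((), 'x'), ((), 'y')) else Set(((), c))
  }

  // Copies its input; the state is true while an even number of 'x' has been seen.
  val evenX: Transducer[Char, Char] = new Transducer[Char, Char] {
    type S = Boolean
    def initialStates: Set[Boolean] = Set(true)
    def finalStates: Set[Boolean] = Set(true)
    def next(even: Boolean, c: Char): Set[(Boolean, Char)] =
      Set((if (c == 'x') !even else even, c))
  }

  val both: Transducer[Char, Char] = TransducerCategory.compose(evenX, ambiguous)

  def main(args: Array[String]): Unit =
    println(run(both)("aa".toList))
}
```

Running both on "aa" yields only List(x, x) and List(y, y): the paths producing xy and yx leave evenX in its odd-parity state, which is not final. Composing with TransducerCategory.id leaves the accepted outputs unchanged, as the identity law requires.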

However, the output encoding of the next function, Set[(S, B)], is not general enough: what if each output is assigned a weight (weighted FSTs)? What if the outputs given a specific state and input form a probability distribution (stochastic FSTs)?

Note that the next function in compose is implemented elegantly through a monadic comprehension. Let’s abstract this monad out:

trait Transducer[F[_], -A, B] {  // B is invariant: F is invariant in general, so +B would not compile
  type S
  def initialStates: F[S]
  def finalStates: F[S]
  def next(s: S, a: A): F[(S, B)]
}

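We call this an F-transducer: the higher-kinded type F encapsulates the effect of moving the transducer from one state to another given an input. F could be Set, which recovers the ordinary (nondeterministic) FSTs above; Weighted[_, R] for a semiring R, giving weighted FSTs; or a probability distribution monad, giving stochastic FSTs.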

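Let’s take the weighted case as an example, since weighted FSTs (WFSTs) are of utmost importance in a variety of domains, e.g. automatic speech recognition. A weighted set \(W\) over \(A\) associates each of its elements with a single weight from \(R\):

\[ W = \{ (a, r) \mid a \in A,\ r \in R \} \]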

  type Weighted[A, R] = Map[A, R]

Weighted forms a monad given a semiring \((R, \oplus, \odot)\) with multiplicative identity \(\mathbb{1}\) on the weight type R (this is how WFSTs work in speech recognition):
\[ \begin{align} \textrm{pure}(a) &= \left\{ (a, \mathbb{1}) \right\} \\
\textrm{map}(W, f) &= \left\{ \left(b, \bigoplus_{(a, r) \in W,\ f(a) = b} r \right) \middle| b \in B \right\} \\
\textrm{flatMap}(W, f) &= \left\{ \left(b, \bigoplus_{(a,r)\in W} \bigoplus_{(b, r^\prime) \in f(a)} r \odot r^\prime \right) \middle| b \in B \right\} \end{align} \]
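For a concrete feel, here is flatMap worked by hand in the tropical semiring \((\mathbb{R} \cup \{\infty\}, \min, +)\), where \(\oplus = \min\), \(\odot = +\) and \(\mathbb{1} = 0\); this semiring is commonly used for path costs in speech recognition. The weighted set \(W\) and the function \(f\) are made-up illustrative values.

\[ \begin{align}
W &= \{ (x, 1), (y, 2) \}, \quad f(x) = \{ (u, 3) \}, \quad f(y) = \{ (u, 1), (v, 4) \} \\
\textrm{flatMap}(W, f) &= \{ (u, \min(1 + 3,\ 2 + 1)),\ (v, 2 + 4) \} = \{ (u, 3), (v, 6) \}
\end{align} \]

Both ways of reaching \(u\) are combined with \(\oplus = \min\), so \(u\) keeps the cheaper cost 3.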

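// Assumes Semiring[R] (with one, plus, times) and Monad (with pure, map, flatMap) type classes.
// ({ type L[X] = Weighted[X, R] })#L is a type lambda: Weighted with the weight type R fixed.
implicit def WeightedMonad[R](implicit R: Semiring[R])
    : Monad[({ type L[X] = Weighted[X, R] })#L] =
  new Monad[({ type L[X] = Weighted[X, R] })#L] {
    def pure[A](a: A): Map[A, R] = Map((a, R.one))
    // map(wa, f) = flatMap(wa, a => pure(f(a))): weights of elements merged by f are added with plus
    def map[A, B](wa: Map[A, R])(f: A => B): Map[B, R] =
      flatMap(wa)(a => pure(f(a)))
    def flatMap[A, B](wa: Map[A, R])(f: A => Map[B, R]): Map[B, R] =
      (
        for {
          (a, r0) <- wa.toList     // a List, not a Map, so pairs sharing a key b are not overwritten
          (b, r1) <- f(a).toList
        } yield (b, R.times(r0, r1))
      ).groupBy {
        case (b, r) => b
      }.map { case (b, brs) => (b, brs.map(_._2).reduce(R.plus(_, _))) }  // semiring plus, not .sum
  }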
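The corrected operations can be checked with concrete numbers. The sketch below uses plain functions instead of a Monad instance (to avoid needing type-lambda support), together with a minimal Semiring trait and a tropical instance; all of these names are assumptions for illustration rather than a specific library's API.

```scala
object WeightedDemo {
  // Minimal semiring type class (assumed; the text does not name a library).
  trait Semiring[R] {
    def zero: R
    def one: R
    def plus(x: R, y: R): R
    def times(x: R, y: R): R
  }

  // Tropical semiring on Double: plus = min, times = +, zero = +infinity, one = 0.
  implicit val tropical: Semiring[Double] = new Semiring[Double] {
    def zero: Double = Double.PositiveInfinity
    def one: Double = 0.0
    def plus(x: Double, y: Double): Double = math.min(x, y)
    def times(x: Double, y: Double): Double = x + y
  }

  type Weighted[A, R] = Map[A, R]

  def pure[A, R](a: A)(implicit sr: Semiring[R]): Weighted[A, R] = Map(a -> sr.one)

  // Multiply weights along each (a, b) pair, then combine pairs that share b with plus.
  def flatMap[A, B, R](wa: Weighted[A, R])(f: A => Weighted[B, R])(
      implicit sr: Semiring[R]): Weighted[B, R] =
    wa.toList
      .flatMap { case (a, r0) => f(a).toList.map { case (b, r1) => (b, sr.times(r0, r1)) } }
      .groupBy(_._1)
      .map { case (b, brs) => b -> brs.map(_._2).foldLeft(sr.zero)(sr.plus) }

  // map as flatMap with pure, so elements that f sends to the same b have their weights summed.
  def map[A, B, R](wa: Weighted[A, R])(f: A => B)(implicit sr: Semiring[R]): Weighted[B, R] =
    flatMap(wa)((a: A) => pure[B, R](f(a)))

  val w: Weighted[String, Double] = Map("x" -> 1.0, "y" -> 2.0)
  val f: String => Weighted[String, Double] = {
    case "x" => Map("u" -> 3.0)
    case _   => Map("u" -> 1.0, "v" -> 4.0)
  }
  val bound: Weighted[String, Double] = flatMap(w)(f)        // u: min(1 + 3, 2 + 1), v: 2 + 4
  val collapsed: Weighted[String, Double] = map(w)(_ => "z") // z: min(1, 2)
}
```

With the tropical semiring, bound is Map(u -> 3.0, v -> 6.0) and collapsed is Map(z -> 1.0). The second value shows why map has to combine weights: converting the mapped pairs with a plain toMap would silently keep just one of the two weights for z.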

Hence Weighted[_, R] is a valid choice of F, which gives us Weighted-transducers, i.e. WFSTs, and therefore a category over WFSTs. Essentially, we have the following chain of implicit derivations:

  Semiring[R] => Monad[Weighted[_, R]]
  Monad[F] => Category[Transducer[F, _, _]]
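The second implication can be sketched generically, since composition only needs pure, map and flatMap on F. Below, Monad is a minimal stand-in trait (cats.Monad has a richer interface), and Cost, costMonad, t1, t2 and the one-step check are hypothetical. To sidestep type lambdas, the weighted monad is written directly for the tropical semiring on Double rather than derived from a Semiring[R].

```scala
object FTransducers {
  // Minimal monad type class (assumed stand-in for a library Monad).
  trait Monad[F[_]] {
    def pure[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
    def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => pure(f(a)))
  }

  // The F-transducer; B is invariant because F is invariant in general.
  trait Transducer[F[_], -A, B] {
    type S
    def initialStates: F[S]
    def finalStates: F[S]
    def next(s: S, a: A): F[(S, B)]
  }

  // Monad[F] => composition of F-transducers (the category's compose).
  def compose[F[_], A, B, C](t2: Transducer[F, B, C], t1: Transducer[F, A, B])(
      implicit F: Monad[F]): Transducer[F, A, C] =
    new Transducer[F, A, C] {
      type S = (t1.S, t2.S)
      def initialStates: F[S] =
        F.flatMap(t1.initialStates)(s1 => F.map(t2.initialStates)(s2 => (s1, s2)))
      def finalStates: F[S] =
        F.flatMap(t1.finalStates)(s1 => F.map(t2.finalStates)(s2 => (s1, s2)))
      def next(s: S, a: A): F[(S, C)] =
        F.flatMap(t1.next(s._1, a)) { case (s1n, b) =>
          F.map(t2.next(s._2, b)) { case (s2n, c) => ((s1n, s2n), c) }
        }
    }

  // Hypothetical weighted effect: tropical costs (plus = min, times = +, one = 0).
  type Cost[X] = Map[X, Double]
  implicit val costMonad: Monad[Cost] = new Monad[Cost] {
    def pure[A](a: A): Cost[A] = Map(a -> 0.0)
    def flatMap[A, B](fa: Cost[A])(f: A => Cost[B]): Cost[B] =
      fa.toList
        .flatMap { case (a, r0) => f(a).toList.map { case (b, r1) => (b, r0 + r1) } }
        .groupBy(_._1)
        .map { case (b, brs) => b -> brs.map(_._2).foldLeft(Double.PositiveInfinity)((x, y) => math.min(x, y)) }
  }

  // t1: 'a' becomes 'x' (cost 1) or 'y' (cost 2); other characters are copied at cost 0.
  val t1: Transducer[Cost, Char, Char] = new Transducer[Cost, Char, Char] {
    type S = Unit
    def initialStates: Cost[Unit] = Map(((), 0.0))
    def finalStates: Cost[Unit] = Map(((), 0.0))
    def next(s: Unit, c: Char): Cost[(Unit, Char)] =
      if (c == 'a') Map(((), 'x') -> 1.0, ((), 'y') -> 2.0) else Map(((), c) -> 0.0)
  }

  // t2: 'x' becomes 1 (cost 3), 'y' becomes 1 (cost 1), anything else becomes 0 (cost 0).
  val t2: Transducer[Cost, Char, Int] = new Transducer[Cost, Char, Int] {
    type S = Unit
    def initialStates: Cost[Unit] = Map(((), 0.0))
    def finalStates: Cost[Unit] = Map(((), 0.0))
    def next(s: Unit, c: Char): Cost[(Unit, Int)] =
      if (c == 'x') Map(((), 1) -> 3.0)
      else if (c == 'y') Map(((), 1) -> 1.0)
      else Map(((), 0) -> 0.0)
  }

  val t21: Transducer[Cost, Char, Int] = compose[Cost, Char, Char, Int](t2, t1)

  // Weighted outputs after reading a single 'a' from the initial states of t21.
  val afterA: Cost[(t21.S, Int)] = costMonad.flatMap(t21.initialStates)(s => t21.next(s, 'a'))
  val outputsAfterA: Cost[Int] = costMonad.map(afterA)(_._2)
}
```

Here outputsAfterA is Map(1 -> 3.0): the path through 'x' costs 1 + 3 and the path through 'y' costs 2 + 1, both land in the same composed state with the same output, and the monad's flatMap merges them with min. Supplying a Monad[Set] instance instead should recover the unweighted composition.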

I think that this is really illuminating, as it shows the structure of FSTs through the lens of abstracted types.