
Planning in latent space

Written: February 13, 2022

Published: January 11, 2026. I originally wrote this article for my old blog back in 2022, but I wanted to have it on my new blog as well, even with some of its flaws. A more complete implementation is available here.

There has been a lot of research into combining tree search with neural networks to let networks perform more complex planning. We have seen this with research like ReBeL, MuZero, and AlphaGo. All of them combine Monte Carlo tree search with neural networks to let a network plan multiple steps into the future.

What is even more interesting is going fully neural and planning in latent space, which brings us to Value Prediction Networks (VPNs).

Value prediction networks

It feels like the VPN paper was missed by many: few people mention it, yet the same people know about MuZero. That said, I think it’s a nice paper worth exploring.

Architecture

The model has four modules with distinct tasks, all trained together with a single optimizer.

  1. First we have the encoding module, which takes the raw state ($\mathcal{S}$) and maps it to an abstract state ($\mathcal{S'}$). All the other modules then only have to deal with $\mathcal{S'}$.
  2. For each abstract state ($\mathcal{S'}$) we need a way to value the state, i.e. a value function. This module is trained to output the value of an abstract state.
  3. Given an abstract state $\mathcal{S'}$ and an action $a$, the outcome module returns the reward ($r$) and how the reward should be discounted ($\gamma$).
  4. The final part is the transition module, which takes an abstract state and an action and produces the next abstract state, i.e. $f(\mathcal{S'}, a) = \mathcal{S''}$.
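To make the interfaces concrete, the four modules could be sketched like this. This is my own toy sketch: the numpy linear maps, the dimensions, and the class name are all stand-ins for real trained networks.

```python
import numpy as np

class VPNModules:
    """Toy stand-ins for the four VPN modules (random linear maps)."""

    def __init__(self, obs_dim, abstract_dim, n_actions, seed=0):
        rng = np.random.default_rng(seed)
        self.n_actions = n_actions
        # 1. Encoding: raw state -> abstract state.
        self.W_enc = rng.normal(size=(abstract_dim, obs_dim))
        # 2. Value: abstract state -> scalar value.
        self.w_val = rng.normal(size=abstract_dim)
        # 3. Outcome: (abstract state, action) -> (reward, gamma).
        self.W_out = rng.normal(size=(n_actions, 2, abstract_dim))
        # 4. Transition: (abstract state, action) -> next abstract state.
        self.W_trans = rng.normal(size=(n_actions, abstract_dim, abstract_dim))

    def encode(self, x):
        return np.tanh(self.W_enc @ x)

    def value(self, s):
        return float(self.w_val @ s)

    def outcome(self, s, a):
        reward, gamma_logit = self.W_out[a] @ s
        # Squash gamma into (0, 1) so it behaves like a discount factor.
        return float(reward), float(1 / (1 + np.exp(-gamma_logit)))

    def transition(self, s, a):
        return np.tanh(self.W_trans[a] @ s)
```

Note that everything after the encoder only ever sees abstract states, which is the whole point of planning in latent space.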

Planning algorithm

To break it down:

  1. First the state is converted into an abstract state.
  2. Then it iterates over all possible actions from that abstract state.
  3. Using the modules above, it calculates a Q-value for each action, and recursively expands the most promising states over multiple depths.

The planning algorithm is a Q-value algorithm. Personally I like seeing algorithms in Python, as it's more complete than pseudo-code, and the steps above could be expressed like this:


class Planner:
    def plan(self, state):
        # Return the best action found by a depth-limited search.
        return self.get_best_actions(state, max_depth=4)[0]

    def q_planning(self, state, action, depth):
        next_state = self.transition(state, action)
        reward, gamma = self.outcome(state, action)
        value = self.value(next_state)

        if depth == 1:
            return reward + gamma * value

        best_actions = self.get_best_actions(next_state)
        best_value = float("-inf")
        for a_t in best_actions:
            best_value = max(
                best_value,
                self.q_planning(
                    state=next_state, action=a_t, depth=depth - 1),
            )
        # Mix the direct value estimate with the deeper search result,
        # weighting the search more as the remaining depth grows.
        return reward + gamma * (
            value / depth + (depth - 1) / depth * best_value)

    def get_best_actions(self, state, max_depth=1):
        # Rank all actions by their Q-value and keep the b best,
        # which bounds the branching factor of the search.
        def score(action):
            return self.q_planning(
                state=state, action=action, depth=max_depth)

        actions = range(self.possible_actions)
        ranked = sorted(actions, key=score, reverse=True)

        b_best = 1
        return ranked[:b_best]
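To sanity-check the recursion, here is a self-contained toy version of the same d-step backup on a made-up deterministic chain, where action 1 always pays reward 1 and action 0 pays nothing. All dynamics here are invented purely for illustration.

```python
def q_plan(state, action, depth, transition, outcome, value):
    """d-step Q backup: value estimates mixed with deeper search."""
    next_state = transition(state, action)
    reward, gamma = outcome(state, action)
    if depth == 1:
        return reward + gamma * value(next_state)
    best = max(
        q_plan(next_state, a, depth - 1, transition, outcome, value)
        for a in (0, 1)
    )
    return reward + gamma * (
        value(next_state) / depth + (depth - 1) / depth * best)

# Toy chain: states are integers, action 1 pays 1, action 0 pays 0.
transition = lambda s, a: s + 1
outcome = lambda s, a: (float(a), 0.9)   # (reward, discount)
value = lambda s: 0.0                    # uninformative value function

q0 = q_plan(0, 0, 4, transition, outcome, value)
q1 = q_plan(0, 1, 4, transition, outcome, value)
assert q1 > q0  # the rewarding action wins
```

With an uninformative value function the signal comes entirely from the predicted rewards, and the deeper backup still prefers the rewarding action at every step.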

Learning and optimizing

The algorithm should be trained using an $\epsilon$-greedy policy. The model can be trained with predictions based on trajectories of earlier history ($x_t, a_t, r_t, \gamma_t$).
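For reference, $\epsilon$-greedy just means following the planner's action most of the time and picking a uniformly random action otherwise. A minimal sketch (the `planner` and `n_actions` names are my own):

```python
import random

def epsilon_greedy_action(planner, state, n_actions, epsilon=0.1):
    """With probability epsilon explore; otherwise follow the planner."""
    if random.random() < epsilon:
        return random.randrange(n_actions)
    return planner.plan(state)
```

The exploration keeps the replayed trajectories diverse enough that all modules see states off the greedy path.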

Our objective is for the model to get better at predicting the Q-value. The state representation and transition modules are hard to optimize for directly in the loss function. However, they can be optimized indirectly by optimizing the value and outcome modules.

The outcome module is probably the easiest to optimize for, and the most intuitive to create an optimization goal for. The predicted $r$ and $\gamma$ should be as accurate as possible given the associated state and action.

The value function is also quite intuitive for people who have looked at RL before. We optimize the value module to match the return of the state ($R_t = r_t + \gamma_t R_{t+1}$).
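As a tiny worked example of that recursion, here is the discounted return computed backwards over a made-up three-step trajectory (the rewards and discounts are invented):

```python
rewards = [1.0, 0.0, 2.0]   # r_t for t = 0, 1, 2
gammas = [0.9, 0.9, 0.9]    # gamma_t for t = 0, 1, 2

R = 0.0  # bootstrap value after the last step (terminal here)
for r, g in zip(reversed(rewards), reversed(gammas)):
    R = r + g * R  # R_t = r_t + gamma_t * R_{t+1}
# R is now 1.0 + 0.9 * (0.0 + 0.9 * 2.0) = 2.62
```

Walking the trajectory backwards like this is exactly what the `get_reward` helper below does.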

This can all be expressed in code like the following:


class Optimizer:
    def loss(self, trajectory, K=3):
        for index in range(len(trajectory) - K):
            state = self.encode(trajectory[index].state)
            R = self.get_reward(trajectory[index:])

            loss = sum(
                (R - self.value(state, k)) ** 2
                + self.get_reward_gamma_error(trajectory, index, k)
                for k in range(K)
            )
            # Accumulate gradients for each window of the trajectory.
            loss.backward()

    def get_reward_gamma_error(self, trajectory, index, k):
        # Roll the abstract state forward k steps with the transition
        # module, then score the outcome module's predictions.
        state = self.encode(trajectory[index].state)
        for i in range(k):
            state = self.transition(
                state, trajectory[index + i].action)

        entry = trajectory[index + k]
        pred_reward, pred_gamma = self.outcome(state, entry.action)

        return (
            (entry.reward - pred_reward) ** 2
            + (entry.gamma - pred_gamma) ** 2
        )

    def value(self, state, depth):
        # Depth-d value: mix the direct value estimate with the best
        # planned Q-value, weighting planning more as depth grows.
        if depth == 0:
            return self.value_module(state)
        a = (1 / (depth + 1)) * self.value_module(state)
        b = (depth / (depth + 1)) * max(
            Planner().q_planning(state, action, depth)
            for action in range(self.possible_actions)
        )
        return a + b

    def get_reward(self, trajectory):
        last = trajectory[-1]
        # Bootstrap with the value estimate unless the episode ended.
        R = 0 if last.is_terminal else self.value_module(
            self.encode(last.state))
        for entry in reversed(trajectory):
            R = entry.reward + entry.gamma * R
        return R

Conclusions

It’s a nice paper, and I recommend everyone read it, especially since it’s quite well written and its ideas are made very accessible.