There has been a lot of research into combining tree search with neural networks to let networks perform more complex planning, with work like ReBeL, MuZero and AlphaGo. All of them combine tree search with neural networks to let a network plan multiple steps into the future.
Even more interesting is going fully neural and planning in latent space, which brings us to Value Prediction Networks (VPN).
The VPN paper seems to have been missed by many: few people mention it, even though the same people know about MuZero. That said, I think it’s a nice paper worth exploring.
The model has four modules with distinct tasks, all trained together with a single optimizer. To break it down:

- An encoding module that maps an observation $x_t$ to an abstract state.
- A transition module that predicts the next abstract state given an action.
- An outcome module that predicts the reward $r$ and discount $\gamma$ of taking an action in an abstract state.
- A value module that predicts the value of an abstract state.
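As a rough sketch of how those four modules fit together (the toy linear/tanh parametrization and the shapes here are my own illustration, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
W_enc = rng.normal(size=(8, 4))        # observation -> abstract state
W_trans = rng.normal(size=(4 + 1, 4))  # (state, action) -> next state
W_out = rng.normal(size=(4 + 1, 2))    # (state, action) -> (reward, gamma)
w_val = rng.normal(size=4)             # state -> scalar value

def encode(observation):
    return np.tanh(observation @ W_enc)

def transition(state, action):
    return np.tanh(np.append(state, action) @ W_trans)

def outcome(state, action):
    reward, gamma_logit = np.append(state, action) @ W_out
    return reward, 1 / (1 + np.exp(-gamma_logit))  # squash gamma into (0, 1)

def value(state):
    return float(state @ w_val)

# One imagined step: encode, predict the outcome, roll the state forward.
state = encode(rng.normal(size=8))
reward, gamma = outcome(state, action=1)
next_state = transition(state, action=1)
one_step_estimate = reward + gamma * value(next_state)
```

Note that the transition, outcome and value modules only ever see abstract states; the raw observation is consumed once by the encoder.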
The planning algorithm is a Q-value rollout. Personally I like seeing algorithms in Python, as it’s more complete than pseudo-code, and the paper’s planning algorithm could be expressed like this:
```python
class Planner:
    def plan(self, state):
        # Act greedily with respect to the depth-limited Q-values.
        return self.get_best_actions(state, max_depth=4)[0]

    def q_planning(self, state, action, depth):
        next_state = self.transition(state, action)
        reward, gamma = self.outcome(state, action)
        value = self.value(next_state)
        if depth == 1:
            return reward + gamma * value
        best_actions = self.get_best_actions(next_state)
        best_value = float("-inf")
        for a_t in best_actions:
            best_value = max(
                best_value,
                self.q_planning(
                    state=next_state, action=a_t, depth=depth - 1),
            )
        # Mix the direct value estimate of the next state with the
        # deeper planned value.
        return reward + gamma * (
            value / depth + (depth - 1) / depth * best_value)

    def get_best_actions(self, state, max_depth=1):
        def score(action):
            return self.q_planning(
                state=state, action=action, depth=max_depth)

        actions = range(self.possible_actions)
        ranked = sorted(actions, key=score, reverse=True)
        b_best = 1  # only expand the single best action at each level
        return ranked[:b_best]
```
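To see the recursive backup in action, here is a standalone toy version (my own construction, not from the paper): states are integers on a chain, action 1 moves right and action 0 moves left, and both the reward and the value estimate equal the next state’s index, so planning deeper should still prefer moving right.

```python
def q_plan(state, action, depth, gamma=0.9, actions=(0, 1)):
    """d-step Q backup on a toy chain: action 1 moves right, 0 moves left;
    reward and state value both equal the next state's index."""
    next_state = state + (1 if action == 1 else -1)
    reward, value = float(next_state), float(next_state)
    if depth == 1:
        return reward + gamma * value
    best = max(q_plan(next_state, a, depth - 1) for a in actions)
    # Mix the direct value estimate with the deeper planned value.
    return reward + gamma * (value / depth + (depth - 1) / depth * best)

best_action = max((0, 1), key=lambda a: q_plan(0, a, depth=3))  # picks 1
```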
The algorithm should be trained using an $\epsilon$-greedy policy. The model can then be trained on predictions over the trajectory of earlier history $(x_t, a_t, r_t, \gamma_t)$.
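An $\epsilon$-greedy pick over the planner’s Q-values could look like this (a minimal sketch; `q_values` is assumed to hold one score per action):

```python
import random

def epsilon_greedy(q_values, epsilon=0.1, rng=random):
    """With probability epsilon explore uniformly, otherwise act greedily."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])
```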
Our objective is for the model to get better at predicting the Q-value. The state representation and transition modules are hard to optimize for directly in the loss function, but they can be optimized indirectly by optimizing the value and outcome modules.
The outcome module is probably the easiest to optimize for, and the most intuitive to create an optimization goal for: the predicted $r$ and $\gamma$ should be as accurate as possible given the associated state and action.
The value function is also quite intuitive for people who have looked at RL before. We optimize the value module to match the return of the state ($R_t = r_t + \gamma_t R_{t+1}$).
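That target return can be accumulated backwards along a trajectory; a minimal sketch (with `bootstrap_value` standing in for the value estimate at the cut-off):

```python
def n_step_return(rewards, gammas, bootstrap_value):
    """Accumulate R_t = r_t + gamma_t * R_{t+1} backwards along a trajectory."""
    R = bootstrap_value
    for r, g in zip(reversed(rewards), reversed(gammas)):
        R = r + g * R
    return R

n_step_return([1.0, 0.0, 2.0], [0.9, 0.9, 0.9], 0.0)  # 1 + 0.9 * (0 + 0.9 * 2)
```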
This can all be expressed in code like the following:
```python
class Optimizer:
    def loss(self, trajectory, K=3):
        for index in range(len(trajectory) - K):
            state = self.encode(trajectory[index].state)
            R = self.get_reward(trajectory[index:])
            loss = sum(
                (R - self.value(state, k)) ** 2
                + self.get_reward_gamma_error(trajectory, index, k)
                for k in range(K)
            )
            loss.backward()

    def get_reward_gamma_error(self, trajectory, index, k):
        # Roll the abstract state forward k steps with the transition
        # module, then compare the predicted reward and discount
        # against the observed ones.
        state = self.encode(trajectory[index].state)
        for i in range(k):
            state = self.transition(
                state, trajectory[index + i].action)
        entry = trajectory[index + k]
        pred_reward, pred_gamma = self.outcome(state, entry.action)
        return (
            (entry.reward - pred_reward) ** 2
            + (entry.gamma - pred_gamma) ** 2
        )

    def value(self, state, depth):
        if depth == 0:
            return self.value_module(state)
        # Mix the direct value estimate with the best depth-limited
        # planned Q-value.
        planner = Planner()
        best_q = max(
            planner.q_planning(state=state, action=a, depth=depth)
            for a in range(planner.possible_actions)
        )
        return (
            (1 / (depth + 1)) * self.value_module(state)
            + (depth / (depth + 1)) * best_q
        )

    def get_reward(self, trajectory):
        # n-step return: bootstrap from the value estimate of the last
        # state unless the trajectory ended in a terminal state.
        last = trajectory[-1]
        R = 0 if last.is_terminal else self.value_module(
            self.encode(last.state))
        for entry in reversed(trajectory):
            R = entry.reward + entry.gamma * R
        return R
```
It’s a nice paper, and I recommend everyone read it, especially since it’s quite well written and the ideas are made very accessible.