# MuZero: DeepMind MCTS (mctx) Integration

## Objective

Replace the custom MCTS logic with `mctx.gumbel_muzero_policy` from DeepMind's `mctx` library, a stable, optimized Monte Carlo Tree Search implementation used in MuZero-style agents. It is designed for latent-space search with learned dynamics models.
## Library

- Name: `mctx` (google-deepmind/mctx)
- Core function: `mctx.gumbel_muzero_policy(...)`
- Backend: JAX
- Batch support: yes (all inputs and outputs are batched)
- Gumbel exploration: built in
## Function Signatures

### `mctx.gumbel_muzero_policy(...)`

```python
output = mctx.gumbel_muzero_policy(
    params=params,
    rng_key=jax.random.PRNGKey(seed),
    root=root,                      # a precomputed mctx.RootFnOutput, not a callable
    recurrent_fn=recurrent_fn,
    num_simulations=50,
    max_depth=5,
    max_num_considered_actions=16,  # Gumbel top-k budget at the root
    qtransform=mctx.qtransform_completed_by_mix_value,  # the default
)
```

Note: the root inference is run once outside the search and passed in as `root`; there is no `root_fn` parameter.
## Required Components

### root_fn(params, obs) → mctx.RootFnOutput

Initial inference at the root node, passed to the search via the `root` argument. Must return:

```python
mctx.RootFnOutput(
    prior_logits: Float[Array, "B A"],
    value: Float[Array, "B"],
    embedding: Float[Array, "B latent_dim"],
)
```

- `prior_logits`: raw action logits from the prediction head
- `value`: scalar value estimate of the root
- `embedding`: latent state output of the representation model `h(obs)`
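The expected root shapes can be sanity-checked with plain NumPy stand-ins before wiring up the real JAX model (`B`, `A`, and `LATENT` below are illustrative values, not mctx requirements):

```python
import numpy as np

B, A, LATENT = 4, 6, 32  # batch size, action count, latent width (illustrative)

# Stand-ins for the three root-inference outputs.
prior_logits = np.zeros((B, A), dtype=np.float32)    # raw logits per action
value = np.zeros((B,), dtype=np.float32)             # scalar value per sample
embedding = np.zeros((B, LATENT), dtype=np.float32)  # latent state h(obs)

assert prior_logits.shape == (B, A)
assert value.shape == (B,)
assert embedding.shape == (B, LATENT)
```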
### recurrent_fn(params, rng_key, action, embedding) → (RecurrentFnOutput, new_embedding)

Transition and prediction step inside the search. Must return a `(output, new_embedding)` tuple, where `output` is:

```python
mctx.RecurrentFnOutput(
    reward: Float[Array, "B"],
    discount: Float[Array, "B"],
    prior_logits: Float[Array, "B A"],
    value: Float[Array, "B"],
)
```

- `reward`: predicted scalar reward for the transition
- `discount`: per-sample discount (e.g. γ, or 0 at terminal states)
- `prior_logits`: predicted action logits at the new state
- `value`: predicted value of the resulting state

The next latent state from the dynamics model is the second element of the returned tuple, not a field of `RecurrentFnOutput`.
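The recurrent-step contract can be sketched as a pure function over NumPy stand-ins (the linear "dynamics" below is a toy placeholder, not part of mctx or the real model):

```python
import numpy as np

B, A, LATENT = 4, 6, 32  # illustrative sizes
rng = np.random.default_rng(0)
W_dyn = rng.normal(size=(LATENT + 1, LATENT)).astype(np.float32)  # toy dynamics weights

def recurrent_step(embedding, action):
    """Toy dynamics + prediction step returning the four
    RecurrentFnOutput-style arrays plus the next latent state."""
    features = np.concatenate([embedding, action[:, None].astype(np.float32)], axis=1)
    next_embedding = np.tanh(features @ W_dyn)         # (B, LATENT)
    reward = next_embedding.mean(axis=1)               # (B,)
    discount = np.full((B,), 0.99, dtype=np.float32)   # (B,)
    prior_logits = np.zeros((B, A), dtype=np.float32)  # (B, A)
    value = np.zeros((B,), dtype=np.float32)           # (B,)
    return (reward, discount, prior_logits, value), next_embedding

output, next_emb = recurrent_step(np.zeros((B, LATENT), np.float32),
                                  np.zeros((B,), np.int32))
assert next_emb.shape == (B, LATENT)
```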
## Outputs from gumbel_muzero_policy

```python
mctx.PolicyOutput(
    action: Int[Array, "B"],
    action_weights: Float[Array, "B A"],
    search_tree: mctx.Tree,
)
```

- `action`: action selected by the improved search policy
- `action_weights`: improved policy distribution from tree search
- `search_tree`: the full search tree; the root value estimate is available as `search_tree.summary().value`
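The role of the Gumbel noise can be illustrated in isolation: at the root, Gumbel MuZero perturbs the prior logits with sampled Gumbel noise before selecting top actions, which is why selected actions vary across RNG keys. A simplified NumPy sketch of the underlying Gumbel-max trick (the real policy additionally mixes in completed Q-values):

```python
import numpy as np

rng = np.random.default_rng(42)
logits = np.array([2.0, 1.0, 0.5, 0.0])  # prior logits for 4 actions

# Gumbel(0, 1) noise: -log(-log(U)) for U ~ Uniform(0, 1).
gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))

# argmax(logits + gumbel) is a sample from softmax(logits) --
# the Gumbel-max trick that underlies Gumbel MuZero's root sampling.
action = int(np.argmax(logits + gumbel))
assert 0 <= action < logits.shape[0]
```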
## Integration Procedure

### 1. Define root_fn

```python
def root_fn(params, obs):
    embedding = model.representation(params, obs)
    logits, value = model.prediction(params, embedding)
    return mctx.RootFnOutput(prior_logits=logits, value=value, embedding=embedding)
```
### 2. Define recurrent_fn

```python
def recurrent_fn(params, rng_key, action, embedding):
    next_embedding, reward = model.dynamics(params, embedding, action)
    logits, value = model.prediction(params, next_embedding)
    output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.full_like(reward, 0.99),  # your environment's discount
        prior_logits=logits,
        value=value,
    )
    return output, next_embedding
```
### 3. Replace existing MCTS logic with:

```python
root = root_fn(model_params, obs)
mcts_output = mctx.gumbel_muzero_policy(
    params=model_params,
    rng_key=jax.random.PRNGKey(seed),
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=50,
    max_depth=5,
)
```
Extract:

- `mcts_output.action`: action to execute
- `mcts_output.action_weights`: improved policy (store for training)
- `mcts_output.search_tree.summary().value`: root value estimate for bootstrapping
## Training Notes

- Store `action_weights` from MCTS, not the raw model policy, as the policy training target
- Use the search root value for the value loss or bootstrapped targets
- Do not backpropagate through `mctx`; gradients flow only through the model
- `reward` and `value` must be scalars per sample, i.e. shape `(B,)`
- All inputs and outputs must be JAX arrays (`jax.numpy`)
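The training-target wiring can be sketched with NumPy (plain softmax cross-entropy against the stored search weights; in a real JAX pipeline the same computation runs under `jax.grad`, with the stored targets treated as constants):

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Stored from search (treated as constants / stop-gradient):
search_weights = np.array([[0.7, 0.2, 0.1]])  # improved policy target
root_value = np.array([0.5])                  # value target (or bootstrap)

# Current model outputs for the same state:
model_logits = np.array([[1.0, 0.0, -1.0]])
model_value = np.array([0.3])

# Cross-entropy pulls the model policy toward the search policy;
# squared error pulls the model value toward the search root value.
policy_loss = -(search_weights * np.log(softmax(model_logits))).sum(axis=-1).mean()
value_loss = ((model_value - root_value) ** 2).mean()
assert policy_loss > 0.0 and value_loss >= 0.0
```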
## Debugging / Validation

- Ensure all tensor shapes match: `(B,)`, `(B, A)`, `(B, latent_dim)`
- Start with 1–2 simulations and verify the action-weight outputs
- Validate that value estimates are within the expected scale
- Confirm that action weights and selected actions vary with the Gumbel noise (i.e. across different `rng_key`s)
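A small shape-check helper makes the first validation step explicit (`check_shapes` is an illustrative helper, not an mctx API):

```python
import numpy as np

def check_shapes(prior_logits, value, embedding, batch, num_actions, latent_dim):
    """Raise ValueError if any root-inference output has the wrong shape."""
    expected = {
        "prior_logits": ((batch, num_actions), prior_logits),
        "value": ((batch,), value),
        "embedding": ((batch, latent_dim), embedding),
    }
    for name, (shape, arr) in expected.items():
        if arr.shape != shape:
            raise ValueError(f"{name}: expected {shape}, got {arr.shape}")

# Passes silently for correctly shaped arrays.
check_shapes(np.zeros((4, 6)), np.zeros((4,)), np.zeros((4, 32)), 4, 6, 32)
```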
## Final Checklist

| Task | Complete |
|---|---|
| `representation`, `dynamics`, and `prediction` defined | ☐ |
| Root inference returns logits, value, embedding | ☐ |
| `recurrent_fn` returns reward, discount, logits, value, and the next embedding | ☐ |
| Custom MCTS fully replaced with `mctx.gumbel_muzero_policy` | ☐ |
| Replay buffer stores search action weights and root value | ☐ |
| Environment compatibility validated | ☐ |
## References

- Library: https://github.com/google-deepmind/mctx
- Docstring: `help(mctx.gumbel_muzero_policy)` in code