✍️ @Arushi Somani June 20, 2024 5:24 PM (PDT)

<aside> 💡 This is a WIP! A reader will find many TODOs and incomplete sections. WIP Repo: https://github.com/somaniarushi/code_mcts

</aside>

Introduction

There’s been a recent wave of papers on adding search to LLMs at inference time, or on using search over a solution space as a tool for data collection. This line of work can be traced back to Chain of Thought prompting [1], which was shown to improve reasoning in large language models. A later extension of prompting lets models search through a tree of thoughts and select the best path [2].

There has also been a surge of effort in combining traditional RL search methods with large language models, a direction the CEO of DeepMind has spoken about as well [3]. In particular, a recent line of work combines Monte Carlo Tree Search (MCTS) with LLM inference, both to create targeted data and to make inference-time search more efficient [4][5][6][7]. MCTS has been shown to boost reasoning [8], and some work has combined code generation with tree search [9]. That last paper, “Planning with Large Language Models for Code Generation”, is what we’ll attempt to implement today.

This implementation aims to reproduce the paper’s core result: that MCTS-guided decoding makes models better at HumanEval. In particular, we take inspiration from https://arunpatro.github.io/blog/mcts/ [10], an excellent introduction to adding MCTS to LLM decoding.

Methodology Description

In this section, we describe how we set up both the control (our baseline) and the experiment itself.

Baseline

We use standard sampling-based decoding as our baseline. For a fair comparison, we attempt to run both methods at constant compute; to do this, we use pass@k evaluation.

Pass@k evaluation draws k generations from a large language model for each problem and counts the problem as passed if any of the k generations is marked correct (for HumanEval, if it passes the problem’s unit tests).
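The definition above can be computed directly from per-sample test results. The sketch below shows that direct version next to the unbiased estimator introduced alongside HumanEval (Chen et al., 2021), which draws $n \ge k$ samples per problem and estimates pass@k as $1 - \binom{n-c}{k} / \binom{n}{k}$, where $c$ is the number of correct samples. Which variant this project uses is not stated here, and the function names are hypothetical.

```python
from math import comb


def pass_at_k_direct(samples_correct: list[list[bool]]) -> float:
    """Fraction of problems where at least one of the k sampled generations passed.

    samples_correct[i] holds the pass/fail result of each of the k samples for problem i.
    """
    solved = sum(any(problem) for problem in samples_correct)
    return solved / len(samples_correct)


def pass_at_k_unbiased(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for one problem from n samples, c of which are correct."""
    if n - c < k:
        # Every size-k subset must contain at least one correct sample.
        return 1.0
    # 1 - P(a random size-k subset of the n samples contains no correct sample)
    return 1.0 - comb(n - c, k) / comb(n, k)
```

The unbiased version is averaged over problems to give the benchmark score; its advantage is lower variance than drawing exactly k samples once.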

We run HumanEval at several values of k on Llama 3 8B (pre-trained) and Llama 3 8B (chat-tuned) to serve as our baselines. For sampling hyperparameters, we use top_k=40 and temperature=0.5.
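To make the baseline’s sampling hyperparameters concrete, here is a minimal sketch of top-k sampling with temperature over a single next-token logit vector. It is illustrative only: in practice the inference library applies these settings, and `sample_top_k` is a hypothetical helper, not part of the project’s code.

```python
import math
import random
from typing import Optional


def sample_top_k(logits: list[float], top_k: int = 40, temperature: float = 0.5,
                 rng: Optional[random.Random] = None) -> int:
    """Sample a token id from the top_k highest logits after temperature scaling."""
    if rng is None:
        rng = random.Random()
    # Keep only the indices of the top_k largest logits; all other tokens get zero probability.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Temperature < 1 sharpens the distribution; subtract the max for numerical stability.
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # random.choices normalizes the weights, so this is a softmax over the kept tokens.
    return rng.choices(top, weights=weights, k=1)[0]
```

With `top_k=1` this reduces to greedy decoding; a lower temperature pushes the sampler toward the same behavior while keeping some diversity across the k generations.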

Monte Carlo Tree Decoding Algorithm

To use Monte Carlo Tree Search for LLM inference, we adopt the following approach. The search starts from a root node containing the prompt $p$, which is assigned a probability of 1. Each iteration then runs four steps: Selection, Expansion, Evaluation, and Backpropagation.
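A minimal sketch of the tree bookkeeping and of one search iteration, assuming each node stores its token sequence plus search statistics. `Node`, `make_root`, and `mcts_iteration` are hypothetical names, not taken from the WIP repo, and the four steps are passed in as callables because this section only details Selection.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class Node:
    """One node in the decoding tree: a token sequence plus search statistics."""
    tokens: list[int]                # prompt tokens plus any tokens generated so far
    prob: float = 1.0                # P(c) for this node's token; the root gets 1
    value: float = 0.0               # V(c), filled in by backpropagation
    visits: int = 0                  # visit(c)
    parent: Optional[Node] = None
    children: list[Node] = field(default_factory=list)

    def is_leaf(self) -> bool:
        return not self.children


def make_root(prompt_tokens: list[int]) -> Node:
    # The root holds the prompt p and is assigned probability 1.
    return Node(tokens=list(prompt_tokens), prob=1.0)


def mcts_iteration(root: Node,
                   select: Callable[[Node], Node],
                   expand: Callable[[Node], None],
                   evaluate: Callable[[Node], float],
                   backprop: Callable[[Node, float], None]) -> None:
    """One Selection -> Expansion -> Evaluation -> Backprop pass with injected steps."""
    node = root
    while not node.is_leaf():   # Step 1: walk down until a leaf is reached
        node = select(node)
    expand(node)                # Step 2: attach children (e.g. likely next tokens)
    reward = evaluate(node)     # Step 3: e.g. complete the program and run tests
    backprop(node, reward)      # Step 4: push the reward back up toward the root
```

Injecting the steps keeps the loop fixed while Expansion, Evaluation, and Backpropagation are still being filled in.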

Step 1: Selection

For selection, we start at the root and repeatedly move to the child chosen by the rule below, until a leaf is reached:

$$ c^{*} = \arg\max_{c \,\in\, \text{children}(n)} \left[ V(c) + \beta \cdot P(c) \cdot \sqrt{\frac{\log \text{visit}(n)}{1 + \text{visit}(c)}} \right] $$

$$ \beta = \log\frac{\text{visit}(n) + c_{\text{base}} + 1}{c_{\text{base}}} + c_{\text{init}} $$

where $n$ is the current node, $c$ ranges over its children, $V(c)$ is the value assigned to child $c$ through backpropagation, $P(c)$ is the probability the model assigns to the token generated at $c$, and $\text{visit}(\cdot)$ is the number of times a node has been visited. $c_{\text{base}}$ and $c_{\text{init}}$ are constant hyperparameters, so $\beta$ grows only logarithmically with the parent’s visit count.

This is a variant of the Upper Confidence Bound (UCB) rule traditionally used in Monte Carlo Tree Search, in which the exploration bonus is additionally weighted by the model’s prior $P(c)$.
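The selection rule can be sketched as follows, assuming $P(c)$ is stored as a probability in [0, 1] and the additive constant in β is called `c_init`. The default values `c_base=10.0` and `c_init=1.0` are placeholders, not values from the source or the paper, and the function names are hypothetical.

```python
from __future__ import annotations

import math
from dataclasses import dataclass, field


@dataclass
class Node:
    prob: float                   # P(c): model probability of the token at this node
    value: float = 0.0            # V(c): filled in by backpropagation
    visits: int = 0               # visit(c)
    children: list[Node] = field(default_factory=list)


def beta(parent_visits: int, c_base: float, c_init: float) -> float:
    """beta = log((visit(n) + c_base + 1) / c_base) + c_init."""
    return math.log((parent_visits + c_base + 1) / c_base) + c_init


def ucb_score(parent: Node, child: Node, c_base: float, c_init: float) -> float:
    """V(c) + beta * P(c) * sqrt(log visit(n) / (1 + visit(c)))."""
    # Assumption: clamp visit(n) to at least 1 so log() is defined for a fresh parent.
    explore = math.sqrt(math.log(max(parent.visits, 1)) / (1 + child.visits))
    return child.value + beta(parent.visits, c_base, c_init) * child.prob * explore


def select_leaf(root: Node, c_base: float = 10.0, c_init: float = 1.0) -> Node:
    """Step 1: from the root, follow the highest-scoring child until a leaf is reached."""
    node = root
    while node.children:
        parent = node
        node = max(parent.children, key=lambda ch: ucb_score(parent, ch, c_base, c_init))
    return node
```

Because the bonus shrinks as $\text{visit}(c)$ grows, a rarely visited child with the same prior can outscore a frequently visited one, which is what keeps the search from collapsing onto a single continuation.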