Here I have restricted the space to include only the states and actions resulting from initially choosing $X \rightarrow Y$ or $Y \rightarrow Z$ and leading to the DAG with edges $(X \rightarrow Y, Y \rightarrow Z, X \rightarrow Z)$. It is clear from the diagram that there can be multiple action sequences leading to the same state. An action sequence can also be thought of as a trajectory $\tau$, such that $\tau = (s_0 \rightarrow \cdots \rightarrow s_T \rightarrow s_f)$ contains the history of how the final object $s_T$ was constructed step-by-step. You will also notice a special state $s_f$. This state is known as the sink and is reached whenever sampling terminates. The sink is not part of the model as such but, as we will see below, is helpful in formalizing a training objective.
The main idea behind GFlowNets is to imagine $Z = \sum_{x \in \mathcal{X}} R(x)$ units of water flowing from the source $s_0$ along the various trajectories and exiting the system through the sink. Crucially, for a state $s$ representing an object $x$, the amount of flow on the path $s \rightarrow s_f$ should equal $R(x)$ while any excess flow should pass on to other states. If we can achieve this, the relative frequency of terminating at $s$ for any unit of flow will be $\frac{R(x)}{Z}$, and we will thus be sampling $x$ in proportion to its reward. The difficulty lies in formalizing a loss whose optimal solution obeys this constraint. To do so, we need to introduce the flow matching conditions.
Let $F(s \rightarrow s')$ denote the flow along the path $s \rightarrow s'$, and let $\text{Pa}(s)$ and $\text{Ch}(s)$ denote the parent states and child states of $s$, respectively. In order for a flow to be valid, the following condition needs to hold for all states other than $s_0$ and $s_f$ [2]:
$$ \sum_{s' \in \text{Pa}(s)} F(s' \rightarrow s) = \sum_{s'' \in \text{Ch}(s)} F(s \rightarrow s'') + R(s) \tag{1} $$
For convenience, we assume $s_f \notin \text{Ch}(s)$ for any $s$. The equality then enforces that the terminal flow $F(s \rightarrow s_f)$ equals $R(s)$, while any excess flow is distributed among the child states in $\text{Ch}(s)$. There are multiple ways to turn this condition into a training objective. For instance, you might introduce a flow estimator $F_\phi$ with parameters $\phi$ and minimize the flow matching loss
$$ \mathcal{L}_{FM}(s) = \left( \log \frac{\sum_{s' \in \text{Pa}(s)} F_\phi (s' \rightarrow s)}{\sum_{s'' \in \text{Ch}(s)} F_\phi (s \rightarrow s'') + R(s)} \right)^2. $$
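To make this concrete, here is a minimal sketch (my own illustration, not code from the repository accompanying this post) of how the flow matching loss for a single state could be computed; the function and argument names `flow_matching_loss`, `parent_flows`, `child_flows` and `reward` are hypothetical:

```python
import torch

def flow_matching_loss(parent_flows: torch.Tensor,
                       child_flows: torch.Tensor,
                       reward: torch.Tensor) -> torch.Tensor:
    """Squared log-ratio between the inflow and outflow of a single state s."""
    inflow = parent_flows.sum()            # sum of F_phi(s' -> s) over parents s'
    outflow = child_flows.sum() + reward   # sum of F_phi(s -> s'') over children s'', plus R(s)
    return (torch.log(inflow) - torch.log(outflow)) ** 2
```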
An alternative (but equivalent) approach is to formulate the problem in terms of a forward policy $P_F(s_{t+1} \mid s_t)$ outputting a distribution over the states reachable from the state at time $t$ (you could also view this as a distribution over $\mathcal{A}(s_t)$), together with an estimate $Z_\theta$ of the total flow originating from $s_0$. If we further introduce a backward policy $P_B(s_t \mid s_{t+1})$, we can formulate what is known as the trajectory balance loss [3]:
$$ \mathcal{L}_{TB}(\tau) = \left( \log \frac{Z_\theta \prod_{t=1}^T P_F(s_{t+1} \mid s_t)}{R(s_T) \prod_{t=1}^T P_B(s_t \mid s_{t+1})} \right)^2, \qquad \tau = (s_0 \rightarrow \cdots \rightarrow s_T \rightarrow s_f) $$
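In code, and assuming we have already summed the log-probabilities of the forward and backward transitions along a trajectory, the loss for that trajectory could be sketched as follows (again an illustration with hypothetical names, not the exact implementation used later):

```python
import torch

def trajectory_balance_loss(log_Z: torch.Tensor,
                            fwd_log_probs: torch.Tensor,
                            back_log_probs: torch.Tensor,
                            reward: torch.Tensor) -> torch.Tensor:
    """Squared difference between the log numerator and log denominator of L_TB."""
    # log_Z:          log of the learned total-flow estimate Z_theta
    # fwd_log_probs:  sum of log P_F(s_{t+1} | s_t) along the trajectory
    # back_log_probs: sum of log P_B(s_t | s_{t+1}) along the trajectory
    # reward:         R(s_T) for the trajectory's final state
    return (log_Z + fwd_log_probs - torch.log(reward) - back_log_probs) ** 2
```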
To understand the trajectory balance loss, consider the detailed balance condition, which simply states that $F(s \rightarrow s')$ can be expressed both as a fraction of the total flow $F(s)$ through $s$ and as a fraction of the total flow $F(s')$ through $s'$:
$$ P_F(s' \mid s) F(s) = P_B(s \mid s') F(s') \tag{2} $$
If this is not intuitive, imagine that Alice and Bob each have 10 liters of water. Alice hands 8 liters to Charlie while Bob hands him 5 liters. The water "flowing between" Alice and Charlie is $0.8 \cdot 10$ but can also be expressed as $\frac{8}{13} \cdot 13$, since Charlie receives 13 liters in total. When $P_F$ and $P_B$ agree in this way, they satisfy the detailed balance condition.
Inspired by this, the trajectory balance loss replaces the flow $F(s')$ on the right-hand side of (2) with the reward $R(s')$. Note that this substitution is only made at the last transition of a trajectory (i.e. into $s_f$), since $s_f$ is the only state for which the incoming flow $F(s_T \rightarrow s_f)$ should always exactly equal the reward $R(s_T)$. Interestingly, the authors argue that there is a unique flow satisfying the flow matching condition in (1) regardless of the choice of $P_B$. We can thus safely set $P_B$ to a uniform distribution over the parent states and focus our attention solely on training $Z_\theta$ and $P_F$.
We will now follow in the footsteps of Bengio et al. [2] and implement a GFlowNet to model a two-dimensional grid environment in which each coordinate has a corresponding reward. $s_0$ will represent the upper-left coordinate, and each action moves one step down or right from the current coordinate. For a grid size of $N=16$, the reward "environment" looks like this:
As can be seen from the image, the reward function has four modes separated by "dead zones" with very low reward. The lower the reward, the harder it will be for the model to explore the environment. In the extreme case of the reward being 0 outside of the modes, the model will be unable to explore the environment, as no flow will be able to pass through the dead zones. As long as the reward is positive, however, the model will be able to explore the entire environment given enough time.
We are going to implement $P_F$ as a small MLP taking as input an $N^2$-dimensional one-hot vector indicating the current position (state), and outputting a distribution over the possible actions: $\text{Down}$, $\text{Right}$ or $\text{Terminate}$. When the current position is at the bottom or right edge of the environment, the invalid action of $\text{Down}$ or $\text{Right}$ will be masked by setting its probability to 0. Finally, we sample an action according to this distribution, update the state accordingly, input the new state to $P_F$, and so on. This goes on until the model chooses the $\text{Terminate}$ action, after which we compute the trajectory balance loss and update $Z_\theta$ and $P_F$. Rinse and repeat.
Let's start by implementing the forward and backward policies.
The most interesting part is `BackwardPolicy`. Since we use a fixed backward policy, the class does not need to inherit from `nn.Module`. When given a state, the `__call__` method simply defaults to assigning $0.5$ probability to having arrived at the state from either of the two parent states (top or left). However, in the case where there is only one parent state (at the left and top edges of the environment), we set the probability to $1$ for the single parent state.
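A minimal sketch of the two classes might look as follows. The class name `ForwardPolicy`, the argument names, and the way states are batched are my own assumptions (the repository's version may differ), but the behaviour matches the description above: a small MLP with masked, renormalized action probabilities, and a backward policy that is uniform over whichever parents exist.

```python
import torch
import torch.nn as nn

class ForwardPolicy(nn.Module):
    """Small MLP mapping one-hot states to a distribution over Down, Right, Terminate."""

    def __init__(self, state_dim: int, hidden_dim: int = 32, num_actions: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )

    def forward(self, s: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # s: (batch, N^2) one-hot states; mask: (batch, 3) with 0 for invalid actions.
        probs = self.net(s).softmax(dim=1) * mask       # zero out invalid actions
        return probs / probs.sum(dim=1, keepdim=True)   # renormalize over valid actions

class BackwardPolicy:
    """Fixed backward policy: uniform over the parents that actually exist."""

    def __init__(self, size: int):
        self.size = size  # grid side length N

    def __call__(self, s: torch.Tensor) -> torch.Tensor:
        idx = s.argmax(dim=1)                          # flat index of each state
        row, col = idx // self.size, idx % self.size
        has_top = (row > 0).float()                    # a parent above exists
        has_left = (col > 0).float()                   # a parent to the left exists
        probs = torch.stack([has_top, has_left], dim=1)
        # s_0 has no parents; the clamp avoids dividing by zero for that (unused) row.
        return probs / probs.sum(dim=1, keepdim=True).clamp(min=1.0)
```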
We also define a `Grid` class responsible for masking invalid actions, taking state-action pairs and outputting updated states, and computing state rewards.
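The sketch below shows one way such a class could look. The method names `mask`, `update` and `reward` are assumptions, and the reward function is only a stand-in with four high-reward patches separated by low-reward dead zones; the exact reward used to produce the plot above lives in the repository.

```python
import torch
import torch.nn.functional as F

class Grid:
    """N x N grid environment; actions are Down (0), Right (1) and Terminate (2)."""

    def __init__(self, size: int):
        self.size = size
        self.state_dim = size ** 2
        self.num_actions = 3

    def mask(self, s: torch.Tensor) -> torch.Tensor:
        # Disallow stepping past the bottom or right edge of the grid.
        idx = s.argmax(dim=1)
        row, col = idx // self.size, idx % self.size
        m = torch.ones(len(s), self.num_actions)
        m[row == self.size - 1, 0] = 0  # no Down from the bottom edge
        m[col == self.size - 1, 1] = 0  # no Right from the right edge
        return m

    def update(self, s: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Down adds `size` to the flat index, Right adds 1, Terminate leaves it unchanged.
        idx = s.argmax(dim=1)
        idx = idx + (actions == 0).long() * self.size + (actions == 1).long()
        return F.one_hot(idx, self.state_dim).float()

    def reward(self, s: torch.Tensor) -> torch.Tensor:
        # Stand-in reward: four high-reward patches near the corners, low reward elsewhere.
        idx = s.argmax(dim=1)
        coords = torch.stack([idx // self.size, idx % self.size], dim=1).float()
        centered = (coords / (self.size - 1) - 0.5).abs()
        in_mode = ((centered > 0.25) & (centered < 0.45)).all(dim=1)
        return 0.1 + 2.0 * in_mode.float()
```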
Finally, we implement the actual GFlowNet.
The `sample_states` method takes as input the initial states `s0` and loops until all states have terminated. The `Log` class is responsible for logging the current estimate of the total flow, the reward for each sample, and the forward and backward probabilities encountered along the trajectory for each sample. These are logged because they are required to compute the trajectory balance loss. I won't go into detail about the implementation of `Log` here, but feel free to check it out on GitHub.
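Ignoring the `Log` bookkeeping, the core of the class might look something like the sketch below, which simply accumulates the summed forward and backward log-probabilities that the trajectory balance loss needs. The `log_Z` parameter name, the `return_log` flag and the return values are my own choices, not necessarily the repository's.

```python
import torch
import torch.nn as nn

class GFlowNet(nn.Module):
    """Sketch of the sampler; trajectory statistics are returned directly instead of via Log."""

    def __init__(self, forward_policy, backward_policy, env):
        super().__init__()
        self.forward_policy = forward_policy
        self.backward_policy = backward_policy
        self.env = env
        self.log_Z = nn.Parameter(torch.zeros(1))  # log of the total-flow estimate Z_theta

    def sample_states(self, s0: torch.Tensor, return_log: bool = False):
        s = s0.clone()
        done = torch.zeros(len(s), dtype=torch.bool)
        fwd_lp = torch.zeros(len(s))   # running sum of log P_F per trajectory
        back_lp = torch.zeros(len(s))  # running sum of log P_B per trajectory

        while not done.all():
            active = ~done
            probs = self.forward_policy(s[active], self.env.mask(s[active]))
            actions = probs.multinomial(1).squeeze(1)

            # Add the log-probability of the chosen forward action (out of place,
            # so that gradients flow back to the forward policy).
            step_fwd = torch.zeros(len(s))
            step_fwd[active] = probs.gather(1, actions.unsqueeze(1)).squeeze(1).log()
            fwd_lp = fwd_lp + step_fwd

            s_next = self.env.update(s[active], actions)

            # Down (0) is undone by "came from top" (0), Right (1) by "came from left" (1);
            # Terminate (2) contributes nothing to the backward product.
            moved = actions < 2
            step_back = torch.zeros(len(s))
            if moved.any():
                pb = self.backward_policy(s_next[moved])
                rows = active.nonzero(as_tuple=True)[0][moved]
                step_back[rows] = pb.gather(1, actions[moved].unsqueeze(1)).squeeze(1).log()
            back_lp = back_lp + step_back

            s[active] = s_next
            done[active] = actions == 2

        if return_log:
            return s, self.log_Z, fwd_lp, back_lp, self.env.reward(s)
        return s
```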
We are now ready to train the GFlowNet.
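Putting the sketches above together, a training loop could look like the following; the hyperparameters (learning rate, batch size, number of steps) are arbitrary placeholders rather than the values used for the results below.

```python
import torch
from torch.nn.functional import one_hot

size = 16
env = Grid(size=size)
model = GFlowNet(ForwardPolicy(state_dim=size ** 2), BackwardPolicy(size=size), env)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

batch_size = 256
for step in range(5_000):
    # Every trajectory starts in the upper-left corner (flat index 0).
    s0 = one_hot(torch.zeros(batch_size, dtype=torch.long), size ** 2).float()
    _, log_Z, fwd_lp, back_lp, reward = model.sample_states(s0, return_log=True)

    # Trajectory balance loss, averaged over the batch.
    loss = ((log_Z + fwd_lp - reward.log() - back_lp) ** 2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```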
Once the model has been trained, we can sample a final batch of $10,000$ samples from the `sample_states` method, this time without needing the logged data.
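With the interface sketched above, drawing that final batch might look like this:

```python
# All 10,000 trajectories again start in the upper-left corner.
s0 = one_hot(torch.zeros(10_000, dtype=torch.long), size ** 2).float()
with torch.no_grad():
    samples = model.sample_states(s0)  # (10_000, N^2) one-hot terminal states
```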
As you can see, the relative frequency of each sample from the trained model (left) is roughly proportional to its corresponding reward (right). Of course, this is a simple toy problem that serves only as a proof of concept. For an example of how GFlowNets can be applied to more realistic problems, see for instance the recent work of Deleu et al. [4] on using GFlowNets for Bayesian structure learning.
GFlowNets are a new area of research and time will tell how useful they will turn out to be. To me, they seem like a promising tool for problems such as drug discovery and causal inference, where the objects of interest are discrete (i.e. molecules and causal graphs) and where the prediction of oracle models or the evaluation of likelihoods can serve as reward functions. If you are interested in more of the thoughts behind GFlowNets, I highly recommend the MLST interview with Yoshua Bengio.
1. Bengio, Yoshua and Lahlou, Salem and Deleu, Tristan and Hu, Edward J. and Tiwari, Mo and Bengio, Emmanuel (2021). "GFlowNet Foundations"
2. Bengio, Emmanuel and Jain, Moksh and Korablyov, Maksym and Precup, Doina and Bengio, Yoshua (2021). "Flow Network based Generative Models for Non-Iterative Diverse Candidate Generation"
3. Malkin, Nikolay and Jain, Moksh and Bengio, Emmanuel and Sun, Chen and Bengio, Yoshua (2022). "Trajectory balance: Improved credit assignment in GFlowNets"
4. Deleu, Tristan and Góis, António and Emezue, Chris and Rankawat, Mansi and Lacoste-Julien, Simon and Bauer, Stefan and Bengio, Yoshua (2022). "Bayesian Structure Learning with Generative Flow Networks"