GFlowNet: Flow Network based Generative Models for Non-Iterative Diverse Candidate Generation

Introduction

논문에서는 positive reward function $R(x)$가 주어졌을 때, terminal state $x$를 reward에 비례한 확률로 sampling하는 stochastic policy를 학습하는 문제를 다룬다.

기존 reinforcement learning은 expected return $R$을 최대화하는 방향으로 policy를 학습하기 때문에, 가장 높은 return을 가지는 single sequence of action에 모든 probability mass를 집중시키는 경향이 있다고 설명한다.

저자들은 drug discovery와 같은 black-box optimization 문제에서는 single mode가 아닌 reward function의 여러 mode에서 diverse candidate를 sampling하는 것이 더 중요하다고 주장한다.

특히 oracle 자체가 uncertain한 경우(cellular assay나 docking simulation처럼 더 정확한 평가의 cheap proxy인 경우), 한 round에서 다양한 후보군을 평가해야 information gain이 커지므로 diversity가 핵심이라고 이야기한다.

논문에서는 이러한 sampling 문제를 energy function을 generative model로 변환하는 문제로 재해석하고, MCMC가 가지는 느린 mixing 문제와 sequential RL이 가지는 local maxima 문제를 동시에 해결하는 GFlowNet을 제안한다.

Contribution

논문에서 제시한 주요 contribution은 다음과 같다.

  • Flow network 관점에서 generative process를 정의하고, local flow-matching condition(node에 들어오는 flow와 나가는 flow가 일치)을 통해 unnormalized probability distribution을 학습하는 GFlowNet을 제안하였다.
  • Flow-matching condition과 generated policy가 target reward function과 일치하는 것 사이의 link를 이론적으로 증명하고, off-policy property와 asymptotic convergence를 함께 보였다.
  • 기존 연구(Buesing et al., 2019)가 generative process를 tree로 가정하여 동일한 final state에 도달하는 multiple action sequence가 존재할 때 실패한다는 점을 분석하였다.
  • Synthetic data에서 single mode에 focus하는 대신 reward function의 distribution과 모든 mode를 modeling하는 것의 유용성을 보였다.
  • Large scale molecule synthesis domain에서 PPO 및 MCMC method와의 비교 실험을 통해 GFlowNet의 효과를 검증하였다.

Method

Problem Setup

논문에서는 discrete set $\mathcal{X}$의 element $x$를 sequential action을 통해 생성하는 policy $\pi(a \mid s)$를 다음 조건을 만족하도록 학습하는 것이 목표라고 설명한다.

\[\pi(x) \approx \frac{R(x)}{Z} = \frac{R(x)}{\sum_{x' \in \mathcal{X}} R(x')}\]

위 수식에서 $R(x) > 0$는 terminal state $x$에 대한 reward를 의미하고, $Z$는 partition function을 나타낸다. $\pi(x)$는 policy를 따라 trajectory를 sampling했을 때 terminal state가 $x$일 확률을 의미한다.

State space $\mathcal{S}$ 중 terminal state의 집합을 $\mathcal{X} \subset \mathcal{S}$로 정의하고, $\mathcal{A}$를 finite alphabet, $\mathcal{A}(s) \subseteq \mathcal{A}$를 state $s$에서 허용된 action의 집합으로 정의한다. Action sequence \(\vec{a} = (a_1, \ldots, a_h)\)를 state로 mapping하는 함수를 \(C : \mathcal{A}^{\ast} \to \mathcal{S}\)로 정의한다.

논문에서는 $C$가 bijective한 경우 generative process를 single root node에서 leaf로 향하는 tree로 visualize할 수 있지만, molecule처럼 동일한 graph를 multiple order로 describe할 수 있는 non-injective case에서는 directed acyclic graph(DAG)로 표현해야 한다고 설명한다.

Figure 1. A flow network MDP. Episodes start at source $$s_0$$ with flow $$Z$$. Like with SMILES strings, there are no cycles. Terminal states are sinks with out-flow $$R(s)$$. Exemplar state $$s_3$$ has parents $$\{(s,a) \mid T(s,a) = s_3\} = \{(s_1, a_2), (s_2, a_5)\}$$ and allowed actions $$A(s_3) = \{a_4, a_7\}$$. $$s_4$$ is a terminal sink state with $$R(s_4) > 0$$ and only one parent. The goal is to estimate $$F(s,a)$$ such that the flow equations are satisfied for all states: for each node, incoming flow equals outgoing flow.

Bias of Tree-based Approaches

논문에서는 기존 연구처럼 pseudo-value $\tilde{V}(s)$를 사용하여 policy를 정의하면 non-injective case에서 심각한 bias가 발생한다고 분석한다. Pseudo-value는 다음과 같이 정의한다.

\[\tilde{V}(s) = \sum_{\vec{b} \in \mathcal{A}^{\ast}(s)} R(s + \vec{b})\]

위 수식에서 \(\mathcal{A}^{\ast}(s)\)는 state \(s\)에서 허용된 continuation의 집합이고, \(s + \vec{b}\)는 state \(s\)에서 action sequence \(\vec{b}\)를 적용하여 도달한 state를 의미한다. 따라서 \(\tilde{V}(s)\)는 \(s\)로부터 reachable한 모든 state의 reward 합을 나타낸다.

이 pseudo-value를 사용하여 policy를 다음과 같이 정의한다고 가정한다.

\[\pi(a \mid s) = \frac{\tilde{V}(s + a)}{\sum_{b \in \mathcal{A}(s)} \tilde{V}(s + b)}\]

Proposition 1에 따르면 $C$가 non-injective이고 $C(\vec{a}_i) = x$를 만족하는 distinct action sequence가 $n(x)$개 존재할 때, 위 policy를 따르면 terminal state의 sampling 확률은 다음과 같이 표현된다고 설명한다.

\[\pi(x) = \frac{n(x) R(x)}{\sum_{x' \in \mathcal{X}} n(x') R(x')}\]

위 수식에서 $n(x)$는 $x$로 mapping되는 distinct action sequence의 개수를 의미한다.

논문에서는 molecule처럼 combinatorial space에서 $C$가 non-injective인 경우, trajectory length가 길어질수록 $n(x)$가 exponential하게 커지기 때문에 larger molecule이 smaller molecule보다 exponentially 더 자주 sampling되는 bias가 발생한다고 이야기한다.

이는 pseudo-value $\tilde{V}$가 MDP의 구조를 tree로 misinterpret하여 잘못된 $\pi(x)$를 생성하기 때문이라고 해석한다.

Flow Network Formulation

논문에서는 MDP를 tree가 아닌 flow network로 해석하여 이 문제를 해결한다고 제안한다.

Flow network는 single source인 root node $s_0$를 가지며 in-flow는 $Z$이다. 각 leaf $x$는 sink이며 out-flow가 $R(x) > 0$이다. $T(s, a) = s’$는 state-action pair $(s, a)$가 state $s’$로 이어진다는 것을 의미하고, $F(s, a)$는 node $s$와 node $s’ = T(s, a)$ 사이의 flow를 나타낸다. $F(s)$는 node $s$를 통과하는 total flow를 의미한다.

$C$가 bijective가 아니므로 root를 제외한 임의의 node는 multiple parent를 가질 수 있으며, 즉 $\lvert{(s, a) \mid T(s, a) = s’}\rvert \geq 1$이다.

State $s’$에 대한 in-flow는 모든 parent edge의 flow 합으로 정의한다.

\[F(s') = \sum_{s, a : T(s, a) = s'} F(s, a)\]

State $s’$의 out-flow는 해당 state에서 나가는 모든 edge의 flow 합으로 정의한다.

\[F(s') = \sum_{a' \in \mathcal{A}(s')} F(s', a')\]

위 수식에서 ${(s, a) : T(s, a) = s’}$는 $s’$의 모든 parent edge를 나타내고, $\mathcal{A}(s’)$는 state $s’$에서 허용된 action의 집합을 의미한다.

Interior node에 대해 $R(s) = 0$, leaf node에 대해 $\mathcal{A}(s) = \emptyset$이라는 convention을 사용하면, 논문에서는 in-flow와 out-flow가 같아야 한다는 flow consistency equation을 다음과 같이 통합한다.

\[\sum_{s, a : T(s, a) = s'} F(s, a) = R(s') + \sum_{a' \in \mathcal{A}(s')} F(s', a')\]

위 수식의 좌변은 in-flow, 우변은 out-flow와 sink로 빠져나가는 reward $R(s’)$의 합을 의미한다.

Policy from Flow

논문에서는 flow consistency equation이 만족되었을 때, 다음 policy가 $\pi(x) \propto R(x)$를 정확히 sampling함을 Proposition 2로 증명한다.

\[\pi(a \mid s) = \frac{F(s, a)}{F(s)}\]

위 수식에서 $F(s, a) > 0$는 allowed edge $(s, a)$의 flow를 의미하고, $F(s) = R(s) + \sum_{a \in \mathcal{A}(s)} F(s, a)$로 non-terminal node에서는 $R(s) = 0$, terminal node에서는 $F(x) = R(x) > 0$이 적용된다.

Proposition 2는 다음 세 가지 결과를 보인다고 설명한다.

  • 임의의 state $s$를 방문할 확률은 $\pi(s) = F(s) / F(s_0)$이다.
  • Source의 total flow는 partition function과 동일하며, $F(s_0) = \sum_{x \in \mathcal{X}} R(x)$가 성립한다.
  • Terminal state에서 $\pi(x) = R(x) / \sum_{x’ \in \mathcal{X}} R(x’)$가 성립한다.

증명은 induction으로 진행된다. $\pi(s_0) = 1$이 base statement로 자명하게 성립하고, parent state $s$들에 대해 statement가 성립한다고 가정하면 다음 식이 도출된다고 설명한다.

\[\pi(s') = \sum_{s, a : T(s, a) = s'} \frac{F(s, a)}{F(s)} \cdot \frac{F(s)}{F(s_0)} = \frac{\sum_{s, a : T(s, a) = s'} F(s, a)}{F(s_0)} = \frac{F(s')}{F(s_0)}\]

이후 $\sum_{x \in \mathcal{X}} \pi(x) = 1$을 사용하여 $F(s_0) = \sum_{x \in \mathcal{X}} R(x)$를 얻고, 이를 대입하여 $\pi(x) = R(x) / \sum_{x’ \in \mathcal{X}} R(x’)$를 도출한다고 설명한다.

이 결과는 $C$가 bijective이든 non-injective이든 동일하게 성립하므로, GFlowNet은 tree-based 방법이 가지는 bias 없이 reward에 비례한 sampling을 수행할 수 있다고 주장한다.

Flow Matching Objective

논문에서는 TD algorithm이 Bellman equation을 학습 objective로 변환하는 방식과 유사하게, flow consistency equation을 학습 objective로 변환한다고 설명한다.

가장 단순한 형태의 loss는 trajectory $\tau$에 대해 다음과 같이 정의한다.

\[\tilde{\mathcal{L}}_\theta(\tau) = \sum_{s' \in \tau \neq s_0} \left( \sum_{s, a : T(s, a) = s'} F_\theta(s, a) - R(s') - \sum_{a' \in \mathcal{A}(s')} F_\theta(s', a') \right)^2\]

위 수식에서 $F_\theta$는 neural network로 parameterize된 flow estimator를 의미하고, summation은 trajectory $\tau$ 내 root가 아닌 모든 state $s’$에 대해 수행된다.

논문에서는 high-dimensional space에서 $\mathcal{X}$의 cardinality가 exponential하면 root에 가까운 early state의 $F(s, a)$와 $F(s)$가 later state보다 exponentially 더 커지므로, neural network output으로 직접 사용하면 serious numerical issue가 발생한다고 지적한다.

이를 해결하기 위해 incoming/outgoing flow의 logarithm을 matching하는 다음 objective를 제안한다.

\[\mathcal{L}_{\theta, \epsilon}(\tau) = \sum_{s' \in \tau \neq s_0} \left( \log \left[ \epsilon + \sum_{s, a : T(s, a) = s'} \exp F^{\log}_\theta(s, a) \right] - \log \left[ \epsilon + R(s') + \sum_{a' \in \mathcal{A}(s')} \exp F^{\log}_\theta(s', a') \right] \right)^2\]

위 수식에서 $F^{\log}_\theta(s, a) = \log F(s, a)$는 log-scale flow estimator를 의미하고, $\epsilon$은 작은 flow의 logarithm을 회피하면서 large flow의 error에 더 큰 weight를 부여하기 위한 hyperparameter를 나타낸다.

논문에서는 log-scale에서의 matching이 incoming flow와 outgoing flow의 ratio를 1에 가깝게 만드는 것과 동등하다고 해석한다. 또한 $\epsilon$을 $R$이 가질 수 있는 smallest value에 가깝게 설정하면 top mode를 발견하는 데 더 큰 pressure를 부여할 수 있다고 설명한다.

Off-policy and Offline Property

논문에서는 GFlowNet의 중요한 특성으로 off-policy 및 offline 학습이 가능하다는 점을 Proposition 3로 증명한다.

저자들은 trajectory를 sampling하는 exploratory policy \(P\)가 consistent flow \(F^{\ast}\)로부터 정의되는 optimal policy \(\pi\)와 동일한 support를 가지고, 충분한 capacity를 가진 estimator(\(\exists \theta : F_\theta = F^{\ast}\))가 사용되면, expected training loss의 minimizer \(\theta^{\ast}\)가 \(F_{\theta^{\ast}} = F^{\ast}\)와 \(\mathcal{L}_{\theta^{\ast}}(\tau) = 0\)을 만족한다고 주장한다. 이때 학습된 policy도 다음을 만족한다.

\[\pi_{\theta^{\ast}}(a \mid s) = \frac{F_{\theta^{\ast}}(s, a)}{\sum_{a' \in \mathcal{A}(s)} F_{\theta^{\ast}}(s, a')}, \quad \pi_{\theta^{\ast}}(x) = \frac{R(x)}{Z}\]

이는 RL에서 asynchronous dynamic programming과 유사한 성질로, 모든 state가 asymptotically infinitely 많이 방문되기만 하면 수렴이 보장된다고 설명한다.

Experiment & Result

Hyper-grid Domain

논문에서는 partition function을 정확히 compute할 수 있는 작은 규모의 hyper-grid domain에서 GFlowNet의 동작을 검증한다고 설명한다.

State는 $n$-dimensional hypercubic grid의 cell이고, agent는 coordinate $i$를 increase시키는 action $a_i$만 선택할 수 있으며 stop action으로 trajectory를 종료한다. 동일한 coordinate에 도달하는 action sequence가 여러 개 존재하므로 이 MDP는 DAG 구조를 가진다고 설명한다.

Reward function은 다음과 같이 정의한다.

\[R(x) = R_0 + R_1 \prod_i \mathbb{I}(0.25 < \lvert x_i / H - 0.5 \rvert) + R_2 \prod_i \mathbb{I}(0.3 < \lvert x_i / H - 0.5 \rvert < 0.4)\]

위 수식에서 $0 < R_0 \ll R_1 < R_2$이고, $\mathbb{I}(\cdot)$는 indicator function을 나타낸다. 이 reward는 grid의 corner 근처에서만 의미 있는 값을 가지며 정확히 $2^n$개의 mode를 가진다. $R_0$를 0에 가깝게 설정할수록 problem이 artificially harder해지며, 탐색하기 undesirable한 region이 생긴다고 설명한다. 논문 실험에서는 $R_1 = 1/2$, $R_2 = 2$로 설정하고 $n = 4$, $H = 8$의 4-D hyper-grid를 사용하였다.

성능 측정 지표는 empirical L1 error $\mathbb{E}[\lvert p(x) - \pi(x) \rvert]$이며, $p(x) = R(x) / Z$는 closed-form으로 계산 가능하고 $\pi$는 repeated sampling과 frequency counting으로 추정한다고 설명한다.

Figure 2. Hypergrid domain. Changing the task difficulty $R_0$ to illustrate the advantage of GFlowNet over others. We see that as $R_0$ gets smaller, MCMC struggles to fit the distribution because it struggles to visit all the modes. PPO also struggles to find all the modes, and requires very large entropy regularization, but is robust to the choice of $R_0$. We plot means over 10 runs for each setting.

실험 결과, GFlowNet은 $R_0$ 값에 robust하며 낮은 L1 error를 달성하였다고 보고한다. Metropolis-Hastings-MCMC는 어떤 L1 error level을 달성하기 위해 GFlowNet보다 exponentially 더 많은 sample이 필요하며, $R_0$가 작아질수록 각 mode를 한 번씩 방문하는 데 훨씬 더 오래 걸린다고 분석한다.

PPO는 모든 mode를 reasonable한 시간 안에 발견하기 위해 entropy maximization term을 usual($\ll 1$)보다 훨씬 큰 0.5로 설정해야 했으며, random agent보다는 빠르게 mode를 발견하지만 GFlowNet보다는 훨씬 느리다고 설명한다. SAC에서도 similar or worse한 결과가 나왔다고 보고한다.

Generating Small Molecules

논문에서는 large-scale 환경으로 small drug molecule 생성 task를 수행한다고 설명한다. 이 환경은 최대 $10^{16}$개의 state와 state에 따라 100~2000개의 action을 가진다.

Molecule은 Jin et al. (2020)의 framework를 따라 predefined vocabulary of building block을 결합하여 부분적으로 생성되며, 결합된 building block은 junction tree를 형성한다고 설명한다. 이는 fragment-based drug design으로도 알려져 있다. Action space는 attachment site 선택과 block 선택의 product space이며, editing sequence를 stop하는 extra action이 존재한다. 동일한 molecule graph에 도달하는 action sequence가 여러 개 존재하고 edge removal action이 없으므로 cycle 없는 DAG MDP가 형성된다고 이야기한다.

Reward는 soluble epoxide hydrolase(sEH) protein에 대한 binding energy를 예측하는 pretrained proxy model로 계산된다고 설명한다. Proxy는 atom graph에 대한 MPNN으로 parameterize되고, flow predictor $F_\theta$는 MARS와 유사하게 junction tree graph(graph of blocks)에 대한 MPNN으로 parameterize된다.

Proxy는 300k molecule의 semi-curated semi-random dataset으로 pretrain되어 test MSE 0.6에 도달하며, docking score(Trott and Olson, 2010)를 대부분 0과 10 사이로 renormalize하여 $R(x) > 0$이 되도록 한다고 설명한다. Sampling 시에는 $\pi(a \mid s)$를 따르는 확률 0.95와 uniform distribution over allowed actions를 따르는 확률 0.05의 mixture를 exploratory policy로 사용한다.

Figure 3. Empirical density of rewards. We verify that GFlowNet is consistent by training it with $R^\beta$, $\beta = 4$, which has the hypothesized effect of shifting the density to the right.

저자들은 reward function을 $R(x)^\beta$로 training했을 때 sampling distribution이 high-reward 방향으로 shift되는 것을 확인하였다고 보고한다. 이는 GFlowNet이 reasonable한 policy $\pi$를 학습하고 있음을 보여준다고 해석한다. MARS도 동일한 shift를 보이지만 GFlowNet이 same $\beta$ value에서 더 많은 high reward molecule을 발견하는데, 이는 MARS가 MCMC method로서 동일한 distribution에 결국 수렴하더라도 더 오랜 시간이 걸리기 때문이라고 설명한다.

Figure 4. The average reward of the top-$k$ as a function of learning (averaged over 3 runs). Only unique hits are counted. Note the log scale. Our method finds more unique good molecules faster.

Top-$k$ unique molecule(SMILES 기준)의 average reward를 비교한 결과, GFlowNet은 MARS, PPO, JT-VAE+BO보다 더 빠르게 high-reward molecule을 발견하였다고 보고한다. PPO는 일정 시점 이후 plateau에 도달하는데, RL이 strongly regularize되지 않으면 good enough trajectory에 만족하기 때문이라고 분석한다. JT-VAE+BO는 expensive Gaussian Process 때문에 동일한 compute time 안에 약 $10^3$개의 molecule만 생성할 수 있어 성능이 낮다고 설명한다.

저자들의 best run에서는 reward가 8 이상인 unique molecule 2339개를 발견하였으며, 그 중 dataset에 포함된 것은 39개뿐이라고 보고한다(proxy dataset의 maximum reward는 10, reward 8 이상은 233개).

Figure 5. Number of diverse Bemis-Murcko scaffolds found above reward threshold $T$ as a function of the number of molecules seen. Left, $T = 7.5$. Right, $T = 8$.

Bemis-Murcko scaffold를 기준으로 mode의 개수를 측정한 결과, GFlowNet은 $R > 8$인 mode를 1500개 이상 발견하였지만 MARS는 100개 미만이라고 보고한다. Top 1000 sample의 average pairwise Tanimoto similarity는 GFlowNet이 $0.44 \pm 0.01$, PPO가 $0.62 \pm 0.03$, MARS가 $0.59 \pm 0.02$로 측정되어, MCMC baseline과 RL baseline이 less diverse한 candidate를 생성한다고 설명한다.

Multi-Round Experiments

논문에서는 true oracle $\mathcal{O}$ 호출에 limited budget이 있는 active learning setting에서 GFlowNet의 성능을 검증한다고 설명한다. Limited dataset $D_0$의 $(x, R(x))$ pair로 proxy $M$을 초기화하고, generative model $\pi_\theta$로 batch $B = {x_1, \ldots, x_k}$를 sampling하여 oracle로 평가한 뒤, newly acquired labeled batch로 proxy $M$을 update하는 process를 $N$ iteration 반복한다.

Figure 6. The top-$k$ return (mean over 3 runs) in the 4-D Hyper-grid task with active learning. GFlowNet gets the highest return faster.

4-D hyper-grid task에서는 proxy로 Gaussian Process를 사용하고 \(\lvert D_0 \rvert = 512\)로 설정한다. Top-\(k\) Return은 \(\text{mean}(\text{top-}k(D_i)) - \text{mean}(\text{top-}k(D_{i-1}))\)로 정의되며 \(k = 10\)을 사용한다. GFlowNet은 baseline 대비 initial set에 대한 return 측면에서 일관되게 outperform한다고 보고한다. Final round 종료 시점의 top-\(k\) point 사이 mean pairwise L2-distance는 GFlowNet \(0.83 \pm 0.03\), MCMC \(0.61 \pm 0.01\), PPO \(0.51 \pm 0.02\)로 측정되어, true oracle 없이도 mode를 capture하는 능력과 multi-round setting에서 diversity의 중요성을 보여준다고 해석한다.

Figure 7. The top-$k$ docking reward (mean over 3 runs) in the molecule task with active learning. GFlowNet consistently generates better samples.

Molecule discovery task에서는 MPNN proxy를 AutoDock(Trott and Olson, 2010)의 docking score를 예측하도록 초기화하고 $\lvert D_0 \rvert = 2000$으로 설정한다. 각 round 끝에 200개의 molecule을 생성하여 AutoDock으로 평가하고 proxy를 update한다고 설명한다.

GFlowNet은 initial set $D_0$보다 significantly 높은 energy를 가진 molecule을 발견하였고, MARS와 random acquisition을 일관되게 outperform한다고 보고한다. Initial set의 mean pairwise Tanimoto similarity는 0.60이며, final round 종료 시점에는 GFlowNet $0.54 \pm 0.04$, MARS $0.64 \pm 0.03$으로 측정되어 GFlowNet의 diversity 우위가 유지된다고 설명한다. PPO training은 unstable하게 consistently diverge하여 결과가 reported되지 않았다고 이야기한다.

Limitation

논문에서는 GFlowNet이 TD-based method와 동일하게 bootstrapping에 의한 optimization challenge를 가질 수 있으며, 이로 인해 성능이 제약될 수 있다고 인정한다. 이는 Deep RL에서 알려진 challenge라고 설명한다.

또한 drug discovery 같은 application에서는 각 mode 주변에서 sampling하는 것이 이미 important advantage이지만, generated sample을 local maxima로 refine하면서도 batch의 diversity를 유지하기 위해 generative approach와 local optimization을 결합하는 방향은 future work로 남아있다고 이야기한다.

추가로 논문은 deterministic generative setting만 다루므로, stochastic environment로의 확장은 RL framework가 다루는 영역이며 별도의 연구가 필요하다고 볼 수 있다.

Reference

Bengio, Emmanuel, et al. “Flow network based generative models for non-iterative diverse candidate generation.” Advances in neural information processing systems 34 (2021): 27381-27394.