AUXteam committed on
Commit
db83589
·
verified ·
1 Parent(s): 083fccb

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ai_scientist/fewshot_examples/132_automated_relational.pdf filter=lfs diff=lfs merge=lfs -text
37
+ ai_scientist/fewshot_examples/2_carpe_diem.pdf filter=lfs diff=lfs merge=lfs -text
38
+ ai_scientist/fewshot_examples/attention.pdf filter=lfs diff=lfs merge=lfs -text
ai_scientist/__init__.py ADDED
File without changes
ai_scientist/fewshot_examples/132_automated_relational.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "review": "{\n \"Summary\": \"The paper provides an interesting direction in the meta-learning field. In particular, it proposes to enhance meta learning performance by fully exploring relations across multiple tasks. To capture such information, the authors develop a heterogeneity-aware meta-learning framework by introducing a novel architecture--meta-knowledge graph, which can dynamically find the most relevant structure for new tasks.\",\n \"Strengths\": [\n \"The paper takes one of the most important issues of meta-learning: task heterogeneity. For me, the problem itself is real and practical.\",\n \"The proposed meta-knowledge graph is novel for capturing the relation between tasks and addressing the problem of task heterogeneity. Graph structure provides a more flexible way of modeling relations. The design for using the prototype-based relational graph to query the meta-knowledge graph is reasonable and interesting.\",\n \"This paper provides comprehensive experiments, including both qualitative analysis and quantitative results, to show the effectiveness of the proposed framework. The newly constructed Art-Multi dataset further enhances the difficulty of tasks and makes the performance more convincing.\"\n ],\n \"Weaknesses\": [\n \"Although the proposed method provides several ablation studies, I still suggest the authors conduct the following ablation studies to enhance the quality of the paper: (1) It might be valuable to investigate the modulation function. In the paper, the authors compare sigmoid, tanh, and Film layer. Can the authors analyze the results by reducing the number of gating parameters in Eq. 10 by sharing the gate value of each filter in Conv layers? (2) What is the performance of the proposed model by changing the type of aggregators?\",\n \"For the autoencoder aggregator, it would be better to provide more details about it, which seems not very clear to me.\",\n \"In the qualitative analysis (i.e., Figure 2 and Figure 3), the authors provide one visualization for each task. It would be more convincing if the authors can provide more cases in the rebuttal period.\"\n ],\n \"Originality\": 3,\n \"Quality\": 3,\n \"Clarity\": 3,\n \"Significance\": 4,\n \"Questions\": [\n \"Please address and clarify the cons above.\"\n ],\n \"Limitations\": [\n \"My major concern is about the clarity of the paper and some additional ablation models (see cons below). Hopefully the authors can address my concern in the rebuttal period.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 3,\n \"Presentation\": 3,\n \"Contribution\": 3,\n \"Overall\": 7,\n \"Confidence\": 5,\n \"Decision\": \"Accept\"\n}"
3
+ }
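For reference, the few-shot example added above stores the review as a JSON-encoded string inside the outer JSON object, so it has to be decoded twice when loaded. A minimal sketch (the keys used below are the ones visible in the file; the loading code itself is illustrative, not part of this commit):

```python
import json

# Load the few-shot review example added in this commit; example["review"] is itself a
# JSON string, so it is decoded a second time.
with open("ai_scientist/fewshot_examples/132_automated_relational.json") as f:
    example = json.load(f)

review = json.loads(example["review"])
print(review["Decision"], review["Overall"])  # fields present in the file above
```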
ai_scientist/fewshot_examples/132_automated_relational.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a29ed4d84f6be5b9547097c2bc8bd57bfe197e91dc8f4ec9bcde6b545e7abe59
3
+ size 1348476
ai_scientist/fewshot_examples/132_automated_relational.txt ADDED
@@ -0,0 +1,1190 @@
1
+ # AUTOMATED RELATIONAL META-LEARNING
2
+
3
+ **Anonymous authors**
4
+ Paper under double-blind review
5
+
6
+ ABSTRACT
7
+
8
+ In order to efficiently learn with a small amount of data on new tasks, meta-learning
9
+ transfers knowledge learned from previous tasks to the new ones. However, a
10
+ critical challenge in meta-learning is the task heterogeneity which cannot be well
11
+ handled by traditional globally shared meta-learning methods. In addition, current
12
+ task-specific meta-learning methods may either suffer from hand-crafted structure
13
+ design or lack the capability to capture complex relations between tasks. In this
14
+ paper, motivated by the way of knowledge organization in knowledge bases, we
15
+ propose an automated relational meta-learning (ARML) framework that automatically extracts the cross-task relations and constructs the meta-knowledge graph.
16
+ When a new task arrives, it can quickly find the most relevant structure and tailor
17
+ the learned structure knowledge to the meta-learner. As a result, the proposed
18
+ framework not only addresses the challenge of task heterogeneity by a learned
19
+ meta-knowledge graph, but also increases the model interpretability. We conduct
20
+ extensive experiments on 2D toy regression and few-shot image classification and
21
+ the results demonstrate the superiority of ARML over state-of-the-art baselines.
22
+
23
+ 1 INTRODUCTION
24
+
25
+ Learning quickly with a few samples is the key characteristic of human intelligence, which remains a
26
+ daunting problem in machine intelligence. The mechanism of learning to learn (a.k.a., meta-learning)
27
+ is widely used to generalize and transfer prior knowledge learned from previous tasks to improve
28
+ the effectiveness of learning on new tasks, which has benefited various applications, ranging from
29
+ computer vision (Kang et al., 2019; Liu et al., 2019) to natural language processing (Gu et al., 2018;
30
+ Lin et al., 2019). Most existing meta-learning algorithms learn a globally shared meta-learner
31
+ (e.g., parameter initialization (Finn et al., 2017), meta-optimizer (Ravi & Larochelle, 2016), metric
32
+ space (Snell et al., 2017)). However, globally shared meta-learners fail to handle tasks lying in
33
+ different distributions, which is known as task heterogeneity. Task heterogeneity has been regarded as
34
+ one of the most challenging issues in few-shot learning, and thus it is desirable to design meta-learning
35
+ models that effectively optimize each of the heterogeneous tasks.
36
+
37
+ The key challenge in dealing with task heterogeneity is how to customize the globally shared meta-learner
+ using task-aware information. Recently, a handful of works try to solve the problem by learning
39
+ a task-specific representation for tailoring the transferred knowledge to each task (Oreshkin et al.,
40
+ 2018; Vuorio et al., 2019; Lee & Choi, 2018). However, the success of these methods relies on the
41
+ impaired knowledge generalization among closely correlated tasks (e.g., the tasks sampled from the
42
+ same distribution). Recently, learning the underlying structure among tasks has provided a more effective
+ way of balancing customization and generalization. Representatively, Yao et al. propose a
44
+ hierarchically structured meta-learning method to customize the globally shared knowledge to each
45
+ cluster in a hierarchical way (Yao et al., 2019). Nonetheless, the hierarchical clustering structure
46
+ completely relies on the handcrafted design which needs to be tuned carefully and may lack the
47
+ capability to capture complex relationships.
48
+
49
+ Hence, we are motivated to propose a framework to automatically extract underlying relational
50
+ structures from previously learned tasks and leverage those relational structures to facilitate knowledge
51
+ customization on a new task. This inspiration comes from the way of structuring knowledge in
52
+ knowledge bases (i.e., knowledge graphs). In knowledge bases, the underlying relational structures
53
+ across text entities are automatically constructed and applied to a new query to improve the searching
54
+ efficiency. In the meta-learning problem, similarly, we aim at automatically establishing the metaknowledge graph between prior knowledge learned from previous tasks. When a new task arrives,
55
+ it queries the meta-knowledge graph and quickly attends to the most relevant entities (nodes), and
56
+ then takes advantage of the relational knowledge structures between them to boost the learning
57
+ effectiveness with the limited training data.
58
+
59
+
60
+ -----
61
+
62
+ The proposed meta-learning framework is named as Automated Relational Meta-Learning (ARML).
63
+ Specifically, the ARML framework automatically builds the meta-knowledge graph from meta-training tasks to memorize and organize learned knowledge from historical tasks, where each vertex
+ represents one type of meta-knowledge (e.g., the common contour between birds and aircraft). To
65
+ learn the meta-knowledge graph at meta-training time, for each task, we construct a prototype-based
66
+ relational graph for each class, where each vertex represents one prototype. The prototype-based
67
+ relational graph not only captures the underlying relationship behind samples, but alleviates the
68
+ potential effects of abnormal samples. The meta-knowledge graph is then learned by summarizing
69
+ the information from the corresponding prototype-based relational graphs of meta-training tasks.
70
+ After constructing the meta-knowledge graph, when a new task comes in, the prototype-based
71
+ relational graph of the new task taps into the meta-knowledge graph for acquiring the most relevant
72
+ knowledge, which further enhances the task representation and facilitates its training process.
73
+
74
+ Our major contributions of the proposed ARML are three-fold: (1) it automatically constructs the
75
+ meta-knowledge graph to facilitate learning a new task; (2) it empirically outperforms state-of-the-art
76
+ meta-learning algorithms; (3) the meta-knowledge graph well captures the relationship among tasks
77
+ and improves the interpretability of meta-learning algorithms.
78
+
79
+ 2 RELATED WORK
80
+
81
+ Meta-learning, allowing machines to learn new skills or adapt to new environments rapidly with a
82
+ few training examples, has been demonstrated to be successful in both supervised learning tasks
83
+ (e.g., few-shot image classification) and reinforcement learning settings. There are mainly three
84
+ research lines of meta-learning: (1) black-box amortized methods design black-box meta-learners
85
+ (e.g., neural networks) to infer the model parameters (Ravi & Larochelle, 2016; Andrychowicz et al.,
86
+ 2016; Mishra et al., 2018); (2) gradient-based methods aim to learn an optimized initialization of
87
+ model parameters, which can be adapted to new tasks by a few steps of gradient descent (Finn et al.,
88
+ 2017; 2018; Lee & Choi, 2018); (3) non-parameteric methods combine parameteric meta-learners
89
+ and non-parameteric learners to learn an appropriate distance metric for few-shot classification (Snell
90
+ et al., 2017; Vinyals et al., 2016; Yang et al., 2018; Oreshkin et al., 2018; Yoon et al., 2019).
91
+
92
+ Our work is built upon the gradient-based meta-learning methods. In the line of gradient-based
93
+ meta-learning, most algorithms learn a globally shared meta-learners from all previous tasks (Finn
94
+ et al., 2017; Li et al., 2017; Flennerhag et al., 2019), to improve the effectiveness of learning process
95
+ on new tasks. However, these algorithms typically lack the ability to handle heterogeneous tasks
96
+ (i.e., tasks sampled from sufficiently different distributions). To tackle this challenge, recent works
97
+ tailor the globally shared initialization to different tasks by leveraging task-specific information (Lee
98
+ & Choi, 2018; Vuorio et al., 2019; Oreshkin et al., 2018) and using probabilistic models (Grant
99
+ et al., 2018; Yoon et al., 2018; Gordon et al., 2019). Recently, HSML customizes the global shared
100
+ initialization with a manually designed hierarchical clustering structure to balance the generalization
101
+ and customization between previous tasks (Yao et al., 2019). However, the hierarchical structure
102
+ may not accurately reflect the real structure since it highly relies on the hand-crafted design. In
103
+ addition, the clustering structure further constrains the complexity of the relational structures that can be captured. In
+ contrast, to customize each task, our proposed ARML leverages the most relevant structure from the meta-knowledge
+ graph, which is automatically constructed from previous knowledge. Thus, ARML not only discovers
+ more accurate underlying structures to improve the effectiveness of meta-learning algorithms, but
+ also enhances the model interpretability through the meta-knowledge graph.
108
+
109
+ 3 PRELIMINARIES
110
+
111
+ **Few-shot Learning** Considering a task Ti, the goal of few-shot learning is to learn a model with
112
+ a dataset $\mathcal{D}_i = \{\mathcal{D}_i^{tr}, \mathcal{D}_i^{ts}\}$, where the labeled training set $\mathcal{D}_i^{tr} = \{\mathbf{x}_j^{tr}, y_j^{tr} \mid \forall j \in [1, N^{tr}]\}$ only has a
+ few samples and $\mathcal{D}_i^{ts}$ represents the corresponding test set. A learning model (a.k.a., base model) $f$
+ with parameters $\theta$ is used to evaluate the effectiveness on $\mathcal{D}_i^{ts}$ by minimizing the expected empirical
+ loss on $\mathcal{D}_i^{tr}$, i.e., $\mathcal{L}(\mathcal{D}_i^{tr}, \theta)$, and obtain the optimal parameters $\theta_i$. For the regression problem, the loss
+ function is defined based on the mean square error, i.e., $\sum_{(\mathbf{x}_j, y_j) \in \mathcal{D}_i^{tr}} \|f_\theta(\mathbf{x}_j) - y_j\|_2^2$, and for the
+ classification problem, the loss function uses the cross-entropy loss, i.e., $-\sum_{(\mathbf{x}_j, y_j) \in \mathcal{D}_i^{tr}} \log p(y_j \mid \mathbf{x}_j, f_\theta)$.
+
+ Usually, optimizing and learning the parameter $\theta$ for task $\mathcal{T}_i$ with only a few labeled training samples
+ is difficult. To address this limitation, meta-learning provides a new perspective to improve the
+ performance by leveraging knowledge from multiple tasks.
123
+
124
+
125
+ -----
126
+
127
+ **Meta-learning and Model-agnostic Meta-learning** In meta-learning, a sequence of tasks
128
+ _{T1, ..., TI_ _} are sampled from a task-level probability distribution p(T ), where each one is a few-shot_
129
+ learning task. To facilitate adaptation to incoming tasks, the meta-learning algorithm aims to find
+ a well-generalized meta-learner over the $I$ training tasks at the meta-training phase. At the meta-testing phase, the
+ optimal meta-learner is applied to adapt to new tasks $\mathcal{T}_t$. In this way, meta-learning algorithms are
132
+ capable of adapting to new tasks efficiently even with a shortage of training data for a new task.
133
+
134
+ Model-agnostic meta-learning (MAML) (Finn et al., 2017), one of the representative algorithms in
135
+ gradient-based meta-learning, regards the meta-learner as the initialization of parameter θ, i.e., θ0,
136
+ and learns a well-generalized initialization $\theta_0^*$ during the meta-training process. The optimization
+ problem is formulated as (with one gradient step as an example):
+
+ $$\theta_0^* := \arg\min_{\theta_0} \sum_{i=1}^{I} \mathcal{L}(f_{\theta_i}, \mathcal{D}_i^{ts}) = \arg\min_{\theta_0} \sum_{i=1}^{I} \mathcal{L}\big(f_{\theta_0 - \alpha \nabla_\theta \mathcal{L}(f_\theta, \mathcal{D}_i^{tr})}, \mathcal{D}_i^{ts}\big). \quad (1)$$
+
+ At the meta-testing phase, to obtain the adaptive parameter $\theta_t$ for each new task $\mathcal{T}_t$, we fine-tune the
+ initialization $\theta_0^*$ by performing a few steps of gradient updates, i.e., $f_{\theta_t} = f_{\theta_0^* - \alpha \nabla_\theta \mathcal{L}(f_\theta, \mathcal{D}_t^{tr})}$.
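As a concrete illustration of the bi-level optimization in Equation 1, below is a minimal sketch of one MAML outer-loop step for a small regression base model. It assumes PyTorch; `sample_task` is a hypothetical helper returning support/query tensors, and the step sizes are placeholders rather than the paper's settings.

```python
import torch

def forward(params, x):
    """Base model f_theta: a small two-layer MLP applied functionally to a parameter list."""
    w1, b1, w2, b2 = params
    return torch.relu(x @ w1 + b1) @ w2 + b2

def maml_outer_step(params, sample_task, inner_lr=0.01, outer_lr=0.001, num_tasks=4):
    """One meta-update of Eq. (1): adapt on D_i^tr, accumulate the loss on D_i^ts."""
    meta_loss = 0.0
    for _ in range(num_tasks):
        x_tr, y_tr, x_ts, y_ts = sample_task()            # hypothetical task sampler
        tr_loss = torch.mean((forward(params, x_tr) - y_tr) ** 2)
        grads = torch.autograd.grad(tr_loss, params, create_graph=True)
        adapted = [p - inner_lr * g for p, g in zip(params, grads)]   # inner-loop step
        meta_loss = meta_loss + torch.mean((forward(adapted, x_ts) - y_ts) ** 2)
    meta_grads = torch.autograd.grad(meta_loss, params)
    return [(p - outer_lr * g).detach().requires_grad_() for p, g in zip(params, meta_grads)]
```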
161
+
162
+ 4 METHODOLOGY
163
+
164
+ In this section, we introduce the details of the proposed ARML. To better explain how it works,
165
+ we show its framework in Figure 1. The goal of ARML is to facilitate the learning process of new
166
+ tasks by leveraging transferable knowledge learned from historical tasks. To achieve this goal, we
167
+ introduce a meta-knowledge graph, which is automatically constructed at the meta-training time, to
168
+ organize and memorize historical learned knowledge. Given a task, which is built as a prototypebased relational structure, it taps into the meta-knowledge graph to acquire relevant knowledge for
169
+ enhancing its own representation. The enhanced prototype representation further aggregate and
170
+ incorporate with meta-learner for fast and effective adaptions by utilizing a modulating function. In
171
+ the following subsections, we elaborate three key components: prototype-based sample structuring,
172
+ automated meta-knowledge graph construction and utilization, and task-specific knowledge fusion
173
+ and adaptation, respectively.
174
+
175
+ [Figure 1 diagram: prototypes → prototype-based relational structure R_i → propagation with the meta-knowledge graph G → aggregators → modulation of the initialization θ0; see the caption below.]
198
+
199
+
200
+ Figure 1: The framework of ARML. For each task $\mathcal{T}_i$, ARML first builds a prototype-based relational
+ structure $\mathcal{R}_i$ by mapping the training samples $\mathcal{D}_i^{tr}$ into prototypes, where each prototype represents
+ one class. Then, $\mathcal{R}_i$ interacts with the meta-knowledge graph $\mathcal{G}$ to acquire the most relevant historical
+ knowledge by information propagation. Finally, the task-specific modulation tailors the globally
+ shared initialization $\theta_0$ by aggregating the raw prototypes and the enriched prototypes, which absorb
206
+ relevant historical information from the meta-knowledge graph.
207
+
208
+ 4.1 PROTOTYPE-BASED SAMPLE STRUCTURING
209
+
210
+ Given a task which involves either classifications or regressions regarding a set of samples, we first
211
+ investigate the relationships among these samples. Such relationship is represented by a graph, called
212
+ prototype-based relational graph in this work, where the vertices in the graph denote the prototypes
213
+ of different classes while the edges and the corresponding edge weights are created based on the
214
+
215
+
216
+ -----
217
+
218
+ similarities between prototypes. Constructing the relational graph based on prototypes instead of raw
219
+ samples allows us to alleviate the issue raised by abnormal samples: abnormal samples, which
+ are located far away from normal samples, can pose significant concerns, especially when only a limited
+ number of samples are available for training. Specifically, for the classification problem, the prototype,
+ denoted by $\mathbf{c}_i^k \in \mathbb{R}^d$, is defined as:
+
+ $$\mathbf{c}_i^k = \frac{1}{N_k^{tr}} \sum_{j=1}^{N^{tr}} \mathbb{1}(y_j = k)\, \mathcal{E}(\mathbf{x}_j), \quad (2)$$
+
+ where $N_k^{tr}$ denotes the number of samples in class $k$. $\mathcal{E}$ is an embedding function, which projects
+ $\mathbf{x}_j$ into a hidden space where samples from the same class are located closer to each other while
+ samples from different classes stay apart. For the regression problem, it is not straightforward to construct
+ the prototypes explicitly based on class information. Therefore, we cluster samples by learning an
+ assignment matrix $\mathbf{P}_i \in \mathbb{R}^{K \times N^{tr}}$. Specifically, we formulate the process as:
+
+ $$\mathbf{P}_i = \mathrm{Softmax}(\mathbf{W}_p \mathcal{E}^{T}(\mathbf{X}) + \mathbf{b}_p), \qquad \mathbf{c}_i^k = \mathbf{P}_i[k]\, \mathcal{F}(\mathbf{X}), \quad (3)$$
+
+ where $\mathbf{P}_i[k]$ represents the $k$-th row of $\mathbf{P}_i$. Thus, training samples are clustered into $K$ clusters, which
+ serve as the representations of the prototypes.
+
+ After calculating all prototype representations $\{\mathbf{c}_i^k \mid \forall k \in [1, K]\}$, which serve as the vertices in the
+ prototype-based relational graph $\mathcal{R}_i$, we further define the edges and the corresponding edge weights.
+ The edge weight $A_{\mathcal{R}_i}(\mathbf{c}_i^j, \mathbf{c}_i^m)$ between two prototypes $\mathbf{c}_i^j$ and $\mathbf{c}_i^m$ is gauged by the similarity
+ between them. Formally:
+
+ $$A_{\mathcal{R}_i}(\mathbf{c}_i^j, \mathbf{c}_i^m) = \sigma(\mathbf{W}_r(|\mathbf{c}_i^j - \mathbf{c}_i^m|/\gamma_r) + \mathbf{b}_r), \quad (4)$$
+
+ where $\mathbf{W}_r$ and $\mathbf{b}_r$ are learnable parameters, $\gamma_r$ is a scalar, and $\sigma$ is the Sigmoid function, which
+ normalizes the weight between 0 and 1. For simplicity, we denote the prototype-based relational graph
+ as $\mathcal{R}_i = (C_{\mathcal{R}_i}, A_{\mathcal{R}_i})$, where $C_{\mathcal{R}_i} = \{\mathbf{c}_i^j \mid \forall j \in [1, K]\} \in \mathbb{R}^{K \times d}$ represents the set of vertices, each of which
+ corresponds to the prototype of a class, while $A_{\mathcal{R}_i} = \{A_{\mathcal{R}_i}(\mathbf{c}_i^j, \mathbf{c}_i^m) \mid \forall j, m \in [1, K]\} \in \mathbb{R}^{K \times K}$
+ gives the adjacency matrix, which indicates the proximity between prototypes.
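To make Equations 2 and 4 concrete, here is a minimal sketch of prototype construction and the similarity-weighted adjacency, assuming PyTorch; the tensor shapes and the linear layer standing in for $\mathbf{W}_r, \mathbf{b}_r$ are illustrative assumptions, not the authors' code.

```python
import torch

def build_prototype_graph(embeddings, labels, num_classes, gamma_r=1.0):
    """Sketch of Eqs. (2) and (4): class prototypes and sigmoid-normalized edge weights.

    embeddings: (N_tr, d) outputs of the embedding function E; labels: (N_tr,) class indices.
    """
    d = embeddings.size(1)
    # Eq. (2): each prototype is the mean embedding of the samples in its class.
    prototypes = torch.stack([embeddings[labels == k].mean(dim=0) for k in range(num_classes)])
    # Eq. (4): edge weight from the scaled absolute difference of prototypes (W_r, b_r learnable).
    w_r = torch.nn.Linear(d, 1)
    diff = (prototypes.unsqueeze(1) - prototypes.unsqueeze(0)).abs() / gamma_r   # (K, K, d)
    adjacency = torch.sigmoid(w_r(diff)).squeeze(-1)                              # (K, K) in (0, 1)
    return prototypes, adjacency
```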
266
+
267
+ 4.2 AUTOMATED META-KNOWLEDGE GRAPH CONSTRUCTION AND UTILIZATION
268
+
269
+ In this section, we first discuss how to organize and distill knowledge from historical learning process
270
+ and then expound how to leverage such knowledge to benefit the training of new tasks. To organize
271
+ and distill knowledge from historical learning process, we construct and maintain a meta-knowledge
272
+ graph. The vertices represent different types of meta-knowledge (e.g., the common contour between
273
+ aircrafts and birds) and the edges are automatically constructed and reflect the relationship between
274
+ meta-knowledge. When serving a new task, we refer to the meta-knowledge, which allows us to
275
+ efficiently and automatically identify relational knowledge from previous tasks. In this way, the
276
+ training of a new task can benefit from related training experience and get optimized much faster
277
+ than otherwise possible. In this paper, the meta-knowledge graph is automatically constructed at the
278
+ meta-training phase. The details of the construction are elaborated as follows:
279
+
280
+ Assuming the representation of a vertex $g$ is given by $\mathbf{h}^g \in \mathbb{R}^d$, we define the meta-knowledge
+ graph as $\mathcal{G} = (H_{\mathcal{G}}, A_{\mathcal{G}})$, where $H_{\mathcal{G}} = \{\mathbf{h}^j \mid \forall j \in [1, G]\} \in \mathbb{R}^{G \times d}$ and $A_{\mathcal{G}} = \{A_{\mathcal{G}}(\mathbf{h}^j, \mathbf{h}^m) \mid \forall j, m \in [1, G]\} \in \mathbb{R}^{G \times G}$
+ denote the vertex feature matrix and vertex adjacency matrix, respectively. To better
+ explain the construction of the meta-knowledge graph, we first discuss the vertex representations $H_{\mathcal{G}}$.
+ During meta-training, tasks arrive one after another in a sequence and their corresponding vertex
+ representations are expected to be updated dynamically in a timely manner. Therefore, the vertex
+ representations of the meta-knowledge graph are parameterized and learned at training time.
+ Moreover, to encourage the diversity of meta-knowledge encoded in the meta-knowledge graph,
+ the vertex representations are randomly initialized. Analogous to the definition of edge weights in the
+ prototype-based relational graph $\mathcal{R}_i$ in equation 4, the weight between a pair of vertices $j$ and $m$ is
+ constructed as:
+
+ $$A_{\mathcal{G}}(\mathbf{h}^j, \mathbf{h}^m) = \sigma(\mathbf{W}_o(|\mathbf{h}^j - \mathbf{h}^m|/\gamma_o) + \mathbf{b}_o), \quad (5)$$
+
+ where $\mathbf{W}_o$ and $\mathbf{b}_o$ represent learnable parameters and $\gamma_o$ is a scalar.
296
+
297
+ To enhance the learning of new tasks with involvement of historical knowledge, we query the
298
+ prototype-based relational graph in the meta-knowledge graph to obtain the relevant knowledge in
299
+ history. The ideal query mechanism is expected to optimize both graph representations simultaneously
300
+
301
+
302
+ -----
303
+
304
+ at the meta-training time, with the training of one graph facilitating the training of the other. In light
305
+ of this, we construct a super-graph Si by connecting the prototype-based relational graph Ri with the
306
+ meta-knowledge graph G for each task Ti. The union of the vertices in Ri and G contributes to the
307
+ vertices in the super-graph. The edges in Ri and G are also reserved in the super-graph. We connect
308
+ $\mathcal{R}_i$ with $\mathcal{G}$ by creating links between the prototype-based relational graph and the meta-knowledge
+ graph. The link between a prototype $\mathbf{c}_i^j$ in the prototype-based relational graph and a vertex $\mathbf{h}^m$ in the
+ meta-knowledge graph is weighted by the similarity between them. More precisely, for each prototype $\mathbf{c}_i^j$,
+ the link weight $A_{\mathcal{S}}(\mathbf{c}_i^j, \mathbf{h}^m)$ is calculated by applying a softmax over the Euclidean distances between $\mathbf{c}_i^j$
+ and $\{\mathbf{h}^m \mid \forall m \in [1, G]\}$ as follows:
+
+ $$A_{\mathcal{S}}(\mathbf{c}_i^j, \mathbf{h}^k) = \frac{\exp\big(-\|(\mathbf{c}_i^j - \mathbf{h}^k)/\gamma_s\|_2^2/2\big)}{\sum_{k'=1}^{G} \exp\big(-\|(\mathbf{c}_i^j - \mathbf{h}^{k'})/\gamma_s\|_2^2/2\big)}, \quad (6)$$
+
+ where $\gamma_s$ is a scaling factor. We denote the intra-adjacency matrix as $A_{\mathcal{S}} = \{A_{\mathcal{S}}(\mathbf{c}_i^j, \mathbf{h}^m) \mid \forall j \in [1, K], m \in [1, G]\} \in \mathbb{R}^{K \times G}$.
+ Thus, for task $\mathcal{T}_i$, the adjacency matrix and feature matrix of the super-graph $\mathcal{S}_i = (\mathbf{A}_i, \mathbf{H}_i)$ are defined as
+ $\mathbf{A}_i = (A_{\mathcal{R}_i}, A_{\mathcal{S}};\, A_{\mathcal{S}}^{T}, A_{\mathcal{G}}) \in \mathbb{R}^{(K+G) \times (K+G)}$ and $\mathbf{H}_i = (C_{\mathcal{R}_i}; H_{\mathcal{G}}) \in \mathbb{R}^{(K+G) \times d}$, respectively.
323
+
324
+ After constructing the super-graph Si, we are able to propagate the most relevant knowledge from
325
+ the meta-knowledge graph $\mathcal{G}$ to the prototype-based relational graph $\mathcal{R}_i$ by introducing a Graph Neural
+ Network (GNN). In this work, following the "message-passing" framework (Gilmer et al., 2017),
+ the GNN is formulated as:
+
+ $$\mathbf{H}_i^{(l+1)} = \mathrm{MP}(\mathbf{A}_i, \mathbf{H}_i^{(l)}; \mathbf{W}^{(l)}), \quad (7)$$
+
+ where $\mathrm{MP}(\cdot)$ is the message-passing function and has several possible implementations (Hamilton
+ et al., 2017; Kipf & Welling, 2017; Veličković et al., 2018), $\mathbf{H}_i^{(l)}$ is the vertex embedding after $l$
+ layers of GNN, and $\mathbf{W}^{(l)}$ is a learnable weight matrix of layer $l$. The input is $\mathbf{H}_i^{(0)} = \mathbf{H}_i$. After stacking
+ $L$ GNN layers, we obtain the information-propagated feature representation for the prototype-based
+ relational graph $\mathcal{R}_i$ as the top-$K$ rows of $\mathbf{H}_i^{(L)}$, denoted as $\hat{C}_{\mathcal{R}_i} = \{\hat{\mathbf{c}}_i^j \mid j \in [1, K]\}$.
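A minimal sketch of Equations 6 and 7, assuming PyTorch: it links prototypes to meta-knowledge vertices, assembles the super-graph, and runs one simple row-normalized message-passing layer as a stand-in for the GCN layer used in the paper; shapes and names are assumptions for illustration.

```python
import torch

def propagate_supergraph(prototypes, meta_vertices, A_R, A_G, W_gcn, gamma_s=1.0):
    """Sketch of Eqs. (6)-(7): build the super-graph S_i and do one message-passing step.

    prototypes: (K, d) C_R_i; meta_vertices: (G, d) H_G; A_R: (K, K); A_G: (G, G);
    W_gcn: (d, d) learnable weight standing in for W^(l).
    """
    # Eq. (6): softmax over scaled squared Euclidean distances gives the intra-adjacency A_S.
    dist2 = torch.cdist(prototypes / gamma_s, meta_vertices / gamma_s) ** 2
    A_S = torch.softmax(-dist2 / 2, dim=1)                       # (K, G)
    # Assemble the super-graph adjacency A_i and feature matrix H_i.
    A_i = torch.cat([torch.cat([A_R, A_S], dim=1),
                     torch.cat([A_S.t(), A_G], dim=1)], dim=0)   # (K+G, K+G)
    H_i = torch.cat([prototypes, meta_vertices], dim=0)          # (K+G, d)
    # Eq. (7): one simplified message-passing layer (row-normalized adjacency, tanh activation).
    A_norm = A_i / A_i.sum(dim=1, keepdim=True).clamp_min(1e-8)
    H_next = torch.tanh(A_norm @ H_i @ W_gcn)
    return H_next[:prototypes.size(0)]                           # top-K rows: enriched prototypes
```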
334
+
335
+ 4.3 TASK-SPECIFIC KNOWLEDGE FUSION AND ADAPTATION
336
+
337
+ After propagating information from the meta-knowledge graph to the prototype-based relational graph, in
+ this section, we discuss how to learn a well-generalized meta-learner for fast and effective adaptation
339
+ to new tasks with limited training data. To tackle the challenge of task heterogeneity, in this
340
+ paper, we incorporate task-specific information to customize the globally shared meta-learner (e.g.,
341
+ initialization here) by leveraging a modulating function, which has been proven to be effective to
342
+ provide customized initialization in previous studies (Wang et al., 2019; Vuorio et al., 2019).
343
+
344
+ The modulating function relies on well-discriminated task representations, while it is difficult to learn
345
+ all the representations by merely utilizing the loss signal derived from the test set $\mathcal{D}_i^{ts}$. To encourage such
+ stability, we introduce two reconstructions by utilizing two autoencoders. There are two collections
+ of representations, i.e., $C_{\mathcal{R}_i}$ and $\hat{C}_{\mathcal{R}_i}$, which contribute the most to the creation of the task-specific
+ meta-learner. $C_{\mathcal{R}_i}$ expresses the raw prototype information without tapping into the meta-knowledge
+ graph, while $\hat{C}_{\mathcal{R}_i}$ gives the prototype representations after absorbing the relevant knowledge from the
+ meta-knowledge graph. Therefore, the two reconstructions are built on $C_{\mathcal{R}_i}$ and $\hat{C}_{\mathcal{R}_i}$. To reconstruct
+ $C_{\mathcal{R}_i}$, an aggregator $\mathrm{AG}^q(\cdot)$ (e.g., a recurrent network or fully connected layers) is used to encode $C_{\mathcal{R}_i}$
+ into a dense representation, which is further fed into a decoder $\mathrm{AG}^q_{dec}(\cdot)$ to achieve the reconstruction.
+ Then, the corresponding task representation $\mathbf{q}_i$ of $C_{\mathcal{R}_i}$ is summarized by applying a mean pooling
+ operator over prototypes on the encoded dense representation. Formally,
+
+ $$\mathbf{q}_i = \mathrm{MeanPool}(\mathrm{AG}^q(C_{\mathcal{R}_i})) = \frac{1}{K}\sum_{j=1}^{K} \mathrm{AG}^q(\mathbf{c}_i^j), \qquad \mathcal{L}_q = \|C_{\mathcal{R}_i} - \mathrm{AG}^q_{dec}(\mathrm{AG}^q(C_{\mathcal{R}_i}))\|_F^2 \quad (8)$$
+
+ Similarly, we reconstruct $\hat{C}_{\mathcal{R}_i}$ and get the corresponding task representation $\mathbf{t}_i$ as follows:
+
+ $$\mathbf{t}_i = \mathrm{MeanPool}(\mathrm{AG}^t(\hat{C}_{\mathcal{R}_i})) = \frac{1}{K}\sum_{j=1}^{K} \mathrm{AG}^t(\hat{\mathbf{c}}_i^j), \qquad \mathcal{L}_t = \|\hat{C}_{\mathcal{R}_i} - \mathrm{AG}^t_{dec}(\mathrm{AG}^t(\hat{C}_{\mathcal{R}_i}))\|_F^2 \quad (9)$$
387
+ The reconstruction errors in Equations 8 and 9 pose an extra constraint to enhance the training
388
+ stability, leading to improvement of task representation learning.
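For concreteness, below is a minimal sketch of a GRU-based aggregator with the reconstruction term of Equations 8 and 9, assuming PyTorch; hidden sizes and class names are illustrative assumptions rather than the paper's exact architecture.

```python
import torch

class PrototypeAggregator(torch.nn.Module):
    """Sketch of Eq. (8)/(9): GRU encoder-decoder over prototypes plus a mean-pooled task vector."""

    def __init__(self, d, hidden=64):
        super().__init__()
        self.enc = torch.nn.GRU(d, hidden, batch_first=True)
        self.dec = torch.nn.GRU(hidden, d, batch_first=True)

    def forward(self, protos):                                   # protos: (K, d), e.g. C_R_i
        enc_out, _ = self.enc(protos.unsqueeze(0))               # (1, K, hidden)
        task_repr = enc_out.mean(dim=1).squeeze(0)               # mean pooling -> q_i or t_i
        recon, _ = self.dec(enc_out)                             # decode back to (1, K, d)
        recon_loss = ((recon.squeeze(0) - protos) ** 2).sum()    # Frobenius-norm reconstruction error
        return task_repr, recon_loss
```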
389
+
390
+
391
+ -----
392
+
393
+ **Algorithm 1 Meta-Training Process of ARML**
394
+
395
+ **Require: p(T ): distribution over tasks; K: Number of vertices in meta-knowledge graph; α: stepsize**
396
+ for gradient descent of each task (i.e., inner loop stepsize); β: stepsize for meta-optimization (i.e.,
397
+ outer loop stepsize); µ1, µ2: balancing factors in loss function
398
+
399
+ 1: Randomly initialize all learnable parameters Φ
400
+ 2: while not done do
401
+ 3: Sample a batch of tasks {Ti|i ∈ [1, I]} from p(T )
402
+
403
+ 4: **for all Ti do**
404
+
405
+ 5: Sample training set Di[tr] [and testing set][ D]i[ts]
406
+
407
+ 6: Construct the prototype-based relational graph Ri by computing prototype in equation 2
408
+ and weight in equation 4
409
+
410
+ 7: Compute the similarity between each prototype and meta-knowledge vertex in equation 6
411
+ and construct the super-graph Si
412
+
413
+ 8: Apply GNN on super-graph Si and get the information-propagated representation **C[ˆ]** _Ri_
414
+
415
+ 9: Aggregate CRi in equation 8 and **C[ˆ]** _Ri in equation 9 to get the representations qi, ti and_
416
+ reconstruction loss Lq, Lt
417
+
418
+ 10: Compute the task-specific initialization $\theta_{0i}$ in equation 10 and update parameters $\theta_i = \theta_{0i} - \alpha \nabla_\theta \mathcal{L}(f_\theta, \mathcal{D}_i^{tr})$
+
+ 11: **end for**
+
+ 12: Update $\Phi \leftarrow \Phi - \beta \nabla_\Phi \sum_{i=1}^{I} \big[ \mathcal{L}(f_{\theta_i}, \mathcal{D}_i^{ts}) + \mu_1 \mathcal{L}_t + \mu_2 \mathcal{L}_q \big]$
+
+ 13: end while
428
+
429
+
430
+ After getting the task representation qi and ti, the modulating function is then used to tailor the
431
+ task-specific information to the globally shared initialization θ0, which is formulated as:
432
+
433
+ $$\theta_{0i} = \sigma(\mathbf{W}_g(\mathbf{t}_i \oplus \mathbf{q}_i) + \mathbf{b}_g) \circ \theta_0, \quad (10)$$
+
+ where $\mathbf{W}_g$ and $\mathbf{b}_g$ are learnable parameters of a fully connected layer. Note that we adopt the Sigmoid
436
+ gating as exemplary and more discussion about different modulating functions can be found in
437
+ ablation studies of Section 5.
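As a concrete illustration of Equation 10, here is a minimal sketch of sigmoid gating over a flattened initialization, assuming PyTorch; the function name and the flat-parameter layout are hypothetical conveniences, not the authors' implementation.

```python
import torch

def modulate_init(theta0_flat, q_i, t_i, W_g, b_g):
    """Eq. (10): sigmoid-gate the shared initialization theta_0 with task representations q_i, t_i.

    theta0_flat: (P,) flattened shared initialization; W_g: (P, 2 * d_task); b_g: (P,).
    """
    gate = torch.sigmoid(W_g @ torch.cat([t_i, q_i]) + b_g)   # element-wise gate in (0, 1)
    return gate * theta0_flat                                  # task-specific initialization theta_0i
```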
438
+
439
+ For each task Ti, we perform the gradient descent process from θ0i and reach its optimal parameter θi.
440
+ Combining the reconstruction loss Lt and Lq with the meta-learning loss defined in equation 1, the
441
+ overall objective function of ARML is:
442
+
443
+ $$\min_\Phi \mathcal{L}_{all} = \min_\Phi \mathcal{L} + \mu_1 \mathcal{L}_t + \mu_2 \mathcal{L}_q = \min_\Phi \sum_{i=1}^{I} \mathcal{L}\big(f_{\theta_{0i} - \alpha \nabla_\theta \mathcal{L}(f_\theta, \mathcal{D}_i^{tr})}, \mathcal{D}_i^{ts}\big) + \mu_1 \mathcal{L}_t + \mu_2 \mathcal{L}_q, \quad (11)$$
450
+
451
+ where µ1 and µ2 are introduced to balance the importance of these three items. Φ represents all
452
+ learnable parameters. The meta-training process of ARML is shown in Alg. 1. The
453
+ details of the meta-testing process of ARML are available in Appendix A.
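As a minimal illustration of Equation 11 and line 12 of Algorithm 1, the sketch below combines the meta-loss with the two reconstruction terms, assuming PyTorch tensors `meta_loss`, `loss_t`, `loss_q` computed as in the earlier sketches; the balancing-factor values are placeholders.

```python
def arml_outer_loss(meta_loss, loss_t, loss_q, mu1=0.01, mu2=0.01):
    """Eq. (11): total objective L_all = L + mu1 * L_t + mu2 * L_q (mu values assumed)."""
    return meta_loss + mu1 * loss_t + mu2 * loss_q

# Typical usage inside the outer loop (Alg. 1, line 12):
#   total = arml_outer_loss(meta_loss, loss_t, loss_q)
#   total.backward(); optimizer.step()
```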
454
+
455
+ 5 EXPERIMENTS
456
+
457
+ In this section, we conduct extensive experiments to demonstrate the effectiveness of the ARML on
458
+ 2D regression and few-shot classification with the goal of answering the following questions: (1) Can
459
+ ARML outperform other meta-learning methods?; (2) Can our proposed components improve the
460
+ learning performance?; (3) Can ARML framework improve the model interpretability by discovering
461
+ reasonable meta-knowledge graph?
462
+
463
+ 5.1 EXPERIMENTAL SETTINGS
464
+
465
+ **Methods for Comparison** We compare our proposed ARML with two types of baselines: gradient-based meta-learning algorithms and non-parametric meta-learning algorithms.
466
+
467
+ _For gradient-based meta-learning methods: both globally shared methods (MAML (Finn et al.,_
468
+ 2017), Meta-SGD (Li et al., 2017)) and task-specific methods (MT-Net (Lee & Choi, 2018), MUMOMAML (Vuorio et al., 2019), HSML (Yao et al., 2019)) are considered for comparison.
469
+
470
+ _For non-parametric meta-learning methods: we select globally shared method Prototypical Network_
471
+ (ProtoNet) (Snell et al., 2017) and task-specific method TADAM (Oreshkin et al., 2018) as baselines.
472
+ Note that, following the traditional settings, non-parametric baselines are only used in few-shot
473
+ classification problem. The detailed implementations of baselines are discussed in Appendix B.3.
474
+
475
+
476
+ -----
477
+
478
+ **Hyperparameter Settings** For the aggregation functions in the autoencoder structure ($\mathrm{AG}^t$, $\mathrm{AG}^t_{dec}$,
+ $\mathrm{AG}^q$, $\mathrm{AG}^q_{dec}$), we use a GRU as the encoder and decoder in this autoencoder framework. We
+ adopt a one-layer GCN (Kipf & Welling, 2017) with tanh activation as the implementation of the GNN
+ in equation 7. For the modulation network, we try sigmoid, tanh, and FiLM modulation and find that
+ sigmoid modulation performs better. Thus, in the following experiments, we use sigmoid modulation as the
+ modulating function. More detailed discussion of the experimental settings is presented in Appendix B.
484
+
485
+ 5.2 2D REGRESSION
486
+
487
+ **Dataset Description.** In 2D regression problem, we adopt the similar regression problem settings
488
+ as (Finn et al., 2018; Vuorio et al., 2019; Yao et al., 2019; Rusu et al., 2019), which includes several
489
+ families of functions. In this paper, to model more complex relational structures, we design a 2D
490
+ regression problem rather than the traditional 1D regression. Inputs $x \sim U[0.0, 5.0]$ and $y \sim U[0.0, 5.0]$
+ are sampled randomly, and random Gaussian noise with standard deviation 0.3 is added to the
+ output. Furthermore, six underlying families of functions are selected:
+ (1) Sinusoids: $z(x, y) = a_s \sin(w_s x + b_s)$, where $a_s \sim U[0.1, 5.0]$, $b_s \sim U[0, 2\pi]$, $w_s \sim U[0.8, 1.2]$;
+ (2) Line: $z(x, y) = a_l x + b_l$, where $a_l \sim U[-3.0, 3.0]$, $b_l \sim U[-3.0, 3.0]$;
+ (3) Quadratic: $z(x, y) = a_q x^2 + b_q x + c_q$, where $a_q \sim U[-0.2, 0.2]$, $b_q \sim U[-2.0, 2.0]$, $c_q \sim U[-3.0, 3.0]$;
+ (4) Cubic: $z(x, y) = a_c x^3 + b_c x^2 + c_c x + d_c$, where $a_c \sim U[-0.1, 0.1]$, $b_c \sim U[-0.2, 0.2]$, $c_c \sim U[-2.0, 2.0]$, $d_c \sim U[-3.0, 3.0]$;
+ (5) Quadratic Surface: $z(x, y) = a_{qs} x^2 + b_{qs} y^2$, where $a_{qs} \sim U[-1.0, 1.0]$, $b_{qs} \sim U[-1.0, 1.0]$;
+ (6) Ripple: $z(x, y) = \sin(-a_r(x^2 + y^2)) + b_r$, where $a_r \sim U[-0.2, 0.2]$, $b_r \sim U[-3.0, 3.0]$.
+ Note that functions 1-4 are located in the subspace of $y = 1$. Following (Finn et al., 2017), we use two fully connected layers with
+ 40 neurons as the base model. The number of vertices of the meta-knowledge graph is set as 6.
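The task distribution above is straightforward to reproduce; here is a minimal sketch for the Sinusoids family (the other five families follow the same pattern), using NumPy; the function and argument names are illustrative.

```python
import numpy as np

def sample_sinusoid_task(k_shot=10, noise_std=0.3, rng=None):
    """Sketch of one 10-shot task from the Sinusoids family: z(x, y) = a*sin(w*x + b) + noise."""
    rng = rng or np.random.default_rng()
    a, b, w = rng.uniform(0.1, 5.0), rng.uniform(0.0, 2 * np.pi), rng.uniform(0.8, 1.2)
    xy = rng.uniform(0.0, 5.0, size=(k_shot, 2))                 # inputs (x, y) ~ U[0, 5]^2
    z = a * np.sin(w * xy[:, 0] + b) + rng.normal(0.0, noise_std, size=k_shot)
    return xy, z
```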
501
+
502
+ **Results and Analysis.** In Figure 2, we summarize the interpretation of meta-knowledge graph
503
+ (see top figure) and the quantitative results (see bottom table) of 10-shot 2D regression. In the
504
+ bottom table, we can observe that ARML achieves the best performance as compared to competitive
505
+ gradient-based meta-learning methods, i.e., globally shared models and task-specific models. This
506
+ finding demonstrates that the meta-knowledge graph is necessary to model and capture task-specific
507
+ information. The superior performance can also be interpreted in the top figure. In the left, we
508
+ show the heatmap between prototypes and meta-knowledge vertices (deeper color means higher
509
+ similarity). We can see that sinusoids and line activate V1 and V4, which may represent curve and
510
+ line, respectively. V1 and V4 also contribute to quadratic and quadratic surface, which also show
511
+ the similarity between these two families of functions. V3 is activated in P0 of all functions and the
512
+ quadratic surface and ripple further activate V1 in P0, which may show the difference between 2D
+ functions and 3D functions (sinusoid, line, quadratic and cubic lie in the subspace). Specifically,
+ in the right figure, we illustrate the meta-knowledge graph, where we set a threshold to filter out
+ links with low similarity scores and show the rest. We can see that V3 is the most popular vertex and
516
525
+ connected with V1, V5 (represent curve) and V4 (represent line). V1 is further connected with V5,
526
+ demonstrating the similarity of curve representation.
527
+
528
+ [Figure 2 (top): similarity heatmap between prototypes (P0, P1) and meta-knowledge vertices (V0-V5), and the learned meta-knowledge graph connecting V0-V5 across the six function families (Sinusoids, Line, Quadratic, Cubic, Quadratic Surface, Ripple).]
+
+ | Model | MAML | Meta-SGD | MT-Net | MUMOMAML | HSML | ARML |
+ | --- | --- | --- | --- | --- | --- | --- |
+ | 10-shot | 2.292 ± 0.163 | 2.908 ± 0.229 | 1.757 ± 0.120 | 0.523 ± 0.036 | 0.494 ± 0.038 | **0.438 ± 0.029** |
547
+ Figure 2: In the top figure, we show the interpretation of meta-knowledge graph. The left heatmap
548
+ shows the similarity between prototypes (P0, P1) and meta-knowledge vertices (V0-V5). The right
549
+ part shows the meta-knowledge graph. In the bottom table, we show the overall performance (mean
550
+ square error with 95% confidence) of 10-shot 2D regression.
551
+
552
+
553
+ -----
554
+
555
+ 5.3 FEW-SHOT CLASSIFICATION
556
+
557
+ **Dataset Description and Settings** In the few-shot classification problem, we first use the benchmark proposed in (Yao et al., 2019), where four fine-grained image classification datasets are included
558
+ (i.e., CUB-200-2011 (Bird), Describable Textures Dataset (Texture), FGVC of Aircraft (Aircraft),
559
+ and FGVCx-Fungi (Fungi)). For each few-shot classification task, it samples classes from one of four
560
+ datasets. In this paper, we call this dataset as Plain-Multi and each fine-grained dataset as subdataset.
561
+
562
+ Then, to demonstrate the effectiveness of our proposed model for handling more complex underlying
563
+ structures, in this paper, we increase the difficulty of few-shot classification problem by introducing
564
+ two image filters: a blur filter and a pencil filter. Similar to (Jerfel et al., 2019), for each image in Plain-Multi, one artistic filter is applied to simulate a changing distribution of few-shot classification
+ tasks. After applying the filters, the total number of subdatasets is 12 and each task is sampled from
+ one of them. This dataset is named Art-Multi. More detailed descriptions of the effects of the different
+ filters are discussed in Appendix C.
568
+
569
+ Following the traditional meta-learning settings, all datasets are divided into meta-training, metavalidation and meta-testing classes. The traditional N-way K-shot settings are used to split training and
570
+ test set for each task. We adopt the standard four-block convolutional layers as the base learner (Finn
571
+ et al., 2017; Snell et al., 2017). The number of vertices of meta-knowledge graph for Plain-Multi
572
+ and Art-Multi datasets are set as 4 and 8, respectively. Additionally, for miniImagenet, where, similar
+ to (Finn et al., 2018), tasks are constructed from a single domain and do not exhibit heterogeneity,
+ we compare our proposed ARML with other baselines and present the results in Appendix D.
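For concreteness, a minimal sketch of N-way K-shot episode construction as used in these benchmarks; the `class_to_images` mapping and the query-set size are assumptions for illustration, not the authors' data pipeline.

```python
import random

def sample_episode(class_to_images, n_way=5, k_shot=5, k_query=15):
    """Build one N-way K-shot task: a support (training) set and a query (test) set."""
    classes = random.sample(list(class_to_images), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        images = random.sample(class_to_images[cls], k_shot + k_query)
        support += [(img, label) for img in images[:k_shot]]   # D_i^tr
        query += [(img, label) for img in images[k_shot:]]     # D_i^ts
    return support, query
```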
575
+
576
+ 5.3.1 PERFORMANCE VALIDATION
577
+
578
+ **Overall Quantitative Analyses** Experimental results for Plain-Multi and Art-Multi are shown in
+ Table 1 and Table 2, respectively. For each dataset, the accuracy with a 95% confidence
+ interval is reported. Note that, due to the space limitation, for the Art-Multi dataset, we only show
+ the average value for each filter; the full results are shown in Table 9 of Appendix E. In
+ these two tables, first, we can observe that task-specific models (MT-Net, MUMOMAML, HSML,
+ TADAM) significantly outperform globally shared models (MAML, Meta-SGD, ProtoNet) in both
+ the gradient-based and non-parametric meta-learning research lines. Second, comparing ARML with
585
+ other task-specific gradient-based meta-learning methods, the better performance confirms that
586
+ ARML can model and extract task-specific information more accurately by leveraging the constructed
587
+ meta-knowledge graph. Especially, the performance gap between the ARML and HSML verifies the
588
+ benefits of relational structure compared with isolated clustering structure. Finally, as a gradient-based
589
+ meta-learning algorithm, ARML can also outperform ProtoNet and TADAM, two representative
590
+ non-parametric meta-learning algorithms.
591
+
592
+ Table 1: Overall few-shot classification results (accuracy ± 95% confidence) on Plain-Multi dataset.
593
+
594
+ | Settings | Algorithms | Data: Bird | Data: Texture | Data: Aircraft | Data: Fungi |
+ | --- | --- | --- | --- | --- | --- |
+ | 5-way 1-shot | MAML | 53.94 ± 1.45% | 31.66 ± 1.31% | 51.37 ± 1.38% | 42.12 ± 1.36% |
+ | | MetaSGD | 55.58 ± 1.43% | 32.38 ± 1.32% | 52.99 ± 1.36% | 41.74 ± 1.34% |
+ | | MT-Net | 58.72 ± 1.43% | 32.80 ± 1.35% | 47.72 ± 1.46% | 43.11 ± 1.42% |
+ | | MUMOMAML | 56.82 ± 1.49% | 33.81 ± 1.36% | 53.14 ± 1.39% | 42.22 ± 1.40% |
+ | | HSML | 60.98 ± 1.50% | 35.01 ± 1.36% | 57.38 ± 1.40% | 44.02 ± 1.39% |
+ | | ProtoNet | 54.11 ± 1.38% | 32.52 ± 1.28% | 50.63 ± 1.35% | 41.05 ± 1.37% |
+ | | TADAM | 56.58 ± 1.34% | 33.34 ± 1.27% | 53.24 ± 1.33% | 43.06 ± 1.33% |
+ | | ARML | 62.33 ± 1.47% | 35.65 ± 1.40% | 58.56 ± 1.41% | 44.82 ± 1.38% |
+ | 5-way 5-shot | MAML | 68.52 ± 0.79% | 44.56 ± 0.68% | 66.18 ± 0.71% | 51.85 ± 0.85% |
+ | | MetaSGD | 67.87 ± 0.74% | 45.49 ± 0.68% | 66.84 ± 0.70% | 52.51 ± 0.81% |
+ | | MT-Net | 69.22 ± 0.75% | 46.57 ± 0.70% | 63.03 ± 0.69% | 53.49 ± 0.83% |
+ | | MUMOMAML | 70.49 ± 0.76% | 45.89 ± 0.69% | 67.31 ± 0.68% | 53.96 ± 0.82% |
+ | | HSML | 71.68 ± 0.73% | 48.08 ± 0.69% | 73.49 ± 0.68% | 56.32 ± 0.80% |
+ | | ProtoNet | 68.67 ± 0.72% | 45.21 ± 0.67% | 65.29 ± 0.68% | 51.27 ± 0.81% |
+ | | TADAM | 69.13 ± 0.75% | 45.78 ± 0.65% | 69.87 ± 0.66% | 53.15 ± 0.82% |
+ | | ARML | 73.34 ± 0.70% | 49.67 ± 0.67% | 74.88 ± 0.64% | 57.55 ± 0.82% |
608
+
609
+ -----
610
+
611
+ Table 2: Overall few-shot classification results (accuracy ± 95% confidence) on Art-Multi dataset.
612
+
613
+ |Settings|Algorithms|Avg. Origninal Avg. Blur Avg. Pencil|
614
+ |---|---|---|
615
+
616
+
617
+ |MAML 42.70 ± 1.35% 40.53 ± 1.38% 36.71 ± 1.37% MetaSGD 44.21 ± 1.38% 42.36 ± 1.39% 37.21 ± 1.39% MT-Net 43.94 ± 1.40% 41.64 ± 1.37% 37.79 ± 1.38% 5-way, 1-shot MUMOMAML 45.63 ± 1.39% 41.59 ± 1.38% 39.24 ± 1.36% HSML 45.68 ± 1.37% 42.62 ± 1.38% 39.78 ± 1.36% Protonet 42.08 ± 1.34% 40.51 ± 1.37% 36.24 ± 1.35% TADAM 44.73 ± 1.33% 42.44 ± 1.35% 39.02 ± 1.34% ARML 47.92 ± 1.34% 44.43 ± 1.34% 41.44 ± 1.34%|MAML MetaSGD MT-Net MUMOMAML HSML|42.70 ± 1.35% 40.53 ± 1.38% 36.71 ± 1.37% 44.21 ± 1.38% 42.36 ± 1.39% 37.21 ± 1.39% 43.94 ± 1.40% 41.64 ± 1.37% 37.79 ± 1.38% 45.63 ± 1.39% 41.59 ± 1.38% 39.24 ± 1.36% 45.68 ± 1.37% 42.62 ± 1.38% 39.78 ± 1.36%|
618
+ |---|---|---|
619
+ ||Protonet TADAM|42.08 ± 1.34% 40.51 ± 1.37% 36.24 ± 1.35% 44.73 ± 1.33% 42.44 ± 1.35% 39.02 ± 1.34%|
620
+ ||ARML|47.92 ± 1.34% 44.43 ± 1.34% 41.44 ± 1.34%|
621
+
622
+
623
+ |ARML 47.92 ± 1.34% 44.43 ± 1.34% 41.44 ± 1.34%|ARML|47.92 ± 1.34% 44.43 ± 1.34% 41.44 ± 1.34%|
624
+ |---|---|---|
625
+ |MAML 58.30 ± 0.74% 55.71 ± 0.74% 49.59 ± 0.73% MetaSGD 57.82 ± 0.72% 55.54 ± 0.73% 50.24 ± 0.72% MT-Net 57.95 ± 0.74% 54.65 ± 0.73% 49.18 ± 0.73% 5-way, 5-shot MUMOMAML 58.60 ± 0.75% 56.29 ± 0.72% 51.15 ± 0.73% HSML 60.63 ± 0.73% 57.91 ± 0.72% 53.93 ± 0.72% Protonet 58.12 ± 0.74% 55.07 ± 0.73% 50.15 ± 0.74% TADAM 60.35 ± 0.72% 58.36 ± 0.73% 53.15 ± 0.74% ARML 61.78 ± 0.74% 58.73 ± 0.75% 55.27 ± 0.73%|MAML MetaSGD MT-Net MUMOMAML HSML|58.30 ± 0.74% 55.71 ± 0.74% 49.59 ± 0.73% 57.82 ± 0.72% 55.54 ± 0.73% 50.24 ± 0.72% 57.95 ± 0.74% 54.65 ± 0.73% 49.18 ± 0.73% 58.60 ± 0.75% 56.29 ± 0.72% 51.15 ± 0.73% 60.63 ± 0.73% 57.91 ± 0.72% 53.93 ± 0.72%|
626
+ ||Protonet TADAM|58.12 ± 0.74% 55.07 ± 0.73% 50.15 ± 0.74% 60.35 ± 0.72% 58.36 ± 0.73% 53.15 ± 0.74%|
627
+ ||ARML|61.78 ± 0.74% 58.73 ± 0.75% 55.27 ± 0.73%|
628
+
629
+
630
+
631
+ **Model Ablation Study** In this section, we perform the ablation study of the proposed ARML to
632
+ demonstrate the effectiveness of each component. The results of ablation study on 5-way, 5-shot
633
+ scenario for Art-Multi dataset are presented in Table 3. In Appendix F, we also show the full results
634
+ for Art-Multi in Table 6 and the ablation study of Plain-Multi in Table 7. Specifically, to show
635
+ the effectiveness of prototype construction, in ablation I, we use the mean pooling aggregation
636
+ of each sample rather than the prototype-based relational graph to interact with meta-knowledge
637
+ graph. In ablation II, we use all samples to construct the sample-level relational graph without
638
+ using the prototype. Compared with ablation I and II, the better performance of ARML shows
639
+ that structuring samples as prototypes can (1) better handle the underlying relations and (2) alleviate the effect of
+ potential anomalies.
641
+
642
+ In ablation III, we remove the meta-knowledge graph and use the prototype-based relational graph
643
+ structure with aggregator $\mathrm{AG}^q$ as the task representation. The better performance of ARML demonstrates the effectiveness of the meta-knowledge graph for capturing the relational structure and facilitating
644
+ the classification performance. We further remove the reconstruction loss and show the results in
645
+ ablation IV and the results demonstrate that the autoencoder structure can benefit the process of
646
+ learning the representation.
647
+
648
+ In ablations V and VI, we change the modulating function to tanh and FiLM (Perez et al., 2018),
649
+ respectively. We can see that ARML is not very sensitive to the modulating function, and sigmoid
650
+ function is slightly better than other activation functions in most cases.
651
+
652
+ Table 3: Results (accuracy ± 95% confidence) of Ablation Models (5-way, 5-shot) on Art-Multi.
653
+
654
+ | Ablation Models | Ave. Original | Ave. Blur | Ave. Pencil |
+ | --- | --- | --- | --- |
+ | I. no prototype-based graph | 60.80 ± 0.74% | 58.36 ± 0.73% | 54.79 ± 0.73% |
+ | II. no prototype | 61.34 ± 0.73% | 58.34 ± 0.74% | 54.81 ± 0.73% |
+ | III. no meta-knowledge graph | 59.99 ± 0.75% | 57.79 ± 0.73% | 53.68 ± 0.74% |
+ | IV. no reconstruction loss | 59.07 ± 0.73% | 57.20 ± 0.74% | 52.45 ± 0.73% |
+ | V. tanh modulation | 62.34 ± 0.74% | 58.58 ± 0.75% | 54.01 ± 0.74% |
+ | VI. film modulation | 60.06 ± 0.75% | 57.47 ± 0.73% | 52.06 ± 0.74% |
+ | ARML | 61.78 ± 0.74% | 58.73 ± 0.75% | 55.27 ± 0.73% |
668
+
669
+
670
+ 5.3.2 ANALYSIS OF CONSTRUCTED META-KNOWLEDGE GRAPH
671
+
672
+ In this section, we conduct extensive analysis for the constructed meta-knowledge graph, which is
673
+ regarded as the key component in ARML. Due to the space limit, we only present the results on ArtMulti datasets. For Plain-Multi, the analysis with similar observations are discussed in Appendix G.
674
+
675
+
676
+ -----
677
+
678
+ **Performance vs. Number of Vertices** We first investigate the impact of the number of vertices in the meta-knowledge graph. The results are shown in Table 4. From the results, we can notice that the
+ performance saturates as the number of vertices reaches around 8. One potential reason is that 8
+ vertices are enough to capture the potential relations. If we had a larger dataset with more complex
+ relations, more vertices might be needed. In addition, if the meta-knowledge graph does not have enough
+ vertices, the worse performance suggests that the graph may not be able to capture enough relations
683
+ across tasks.
684
+
685
+ Table 4: Sensitivity analysis with different # of vertices in meta-knowledge graph (5-way, 5-shot).
686
+
687
+ | # of vertices | Ave. Original | Ave. Blur | Ave. Pencil |
+ | --- | --- | --- | --- |
+ | 4 | 61.18 ± 0.72% | 58.13 ± 0.73% | 54.88 ± 0.75% |
+ | 8 | 61.78 ± 0.74% | 58.73 ± 0.75% | 55.27 ± 0.73% |
+ | 12 | 61.66 ± 0.73% | 58.61 ± 0.72% | 55.07 ± 0.74% |
+ | 16 | 61.75 ± 0.73% | 58.67 ± 0.74% | 55.26 ± 0.73% |
+ | 20 | 61.91 ± 0.74% | 58.92 ± 0.73% | 55.24 ± 0.72% |
693
+
694
+
695
+
696
+ **Model Interpretation Analysis of Meta-Knowledge Graph** We then analyze the learned metaknowledge graph. For each subdataset, we randomly select one task as exemplary. For each task,
697
+ in the left part of Figure 3 we show the similarity heatmap between prototypes and vertices in
698
+ the meta-knowledge graph, where deeper color means higher similarity. V0-V7 and P0-P5 denote
+ the different vertices and prototypes, respectively. The meta-knowledge graph is also illustrated
+ in the right part. Similar to the graph in 2D regression, we set a threshold to filter links with low
+ similarity and illustrate the rest of them. First, we can see that V1 is mainly activated by bird
702
+ and aircraft (including all filters), which may reflect the shape similarity between bird and aircraft.
703
+ Second, V2, V3, V4 are firstly activated by texture and they form a loop in the meta-knowledge
704
+ graph. Especially, V2 also benefits images with blur and pencil filters. Thus, V2 may represent the
705
+ main texture and facilitate the training process on other subdatasets. The meta-knowledge graph also
706
+ shows the importance of V2 since it is connected with almost all other vertices. Third, when we use
707
+ blur filter, in most cases (bird blur, texture blur, fungi blur), V7 is activated. Thus, V7 may show the
708
+ similarity of images with the blur filter. In addition, the connections between V7 and V2 and V3 show that
+ classifying blurred images may depend on the texture information. Fourth, V6 (activated mostly by aircraft)
+ connects with V2 and V3, justifying the importance of texture information for classifying aircraft.
711
+
712
+ [Figure 3 diagram: similarity heatmaps for one task from each subdataset (Bird, Texture, Aircraft, Fungi, with original, blur, and pencil variants) and the learned meta-knowledge graph over vertices V0-V7; see the caption below.]
729
+
730
+
731
+ Figure 3: Interpretation of meta-knowledge graph on Art-Multi dataset. For each subdataset, we
732
+ randomly select one task from them. In the left, we show the similarity heatmap between prototypes
733
+ (P0-P5) and meta-knowledge vertices (V0-V7). In the right part, we show the meta-knowledge graph.
734
+
735
+ 6 CONCLUSION
736
+
737
+ In this paper, to improve the effectiveness of meta-learning for handling heterogeneous tasks, we
+ propose a new framework called ARML, which automatically extracts relations across tasks and
+ constructs a meta-knowledge graph. When a new task comes in, it can quickly find the most relevant
740
+ relations through the meta-knowledge graph and use this knowledge to facilitate its training process.
741
+ The experiments demonstrate the effectiveness of our proposed algorithm.
742
+
743
+
744
+ -----
745
+
746
+ REFERENCES
747
+
748
+ Marcin Andrychowicz, Misha Denil, Sergio Gomez, Matthew W Hoffman, David Pfau, Tom Schaul,
749
+ Brendan Shillingford, and Nando De Freitas. Learning to learn by gradient descent by gradient
750
+ descent. In NeurIPS, pp. 3981–3989, 2016.
751
+
752
+ Chelsea Finn and Sergey Levine. Meta-learning and universality: Deep representations and gradient
753
+ descent can approximate any learning algorithm. In ICLR, 2018.
754
+
755
+ Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of
756
+ deep networks. In ICML, pp. 1126–1135, 2017.
757
+
758
+ Chelsea Finn, Kelvin Xu, and Sergey Levine. Probabilistic model-agnostic meta-learning. In NeurIPS,
759
+ 2018.
760
+
761
+ Sebastian Flennerhag, Pablo G Moreno, Neil D Lawrence, and Andreas Damianou. Transferring
762
+ knowledge across learning processes. ICLR, 2019.
763
+
764
+ Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural
765
+ message passing for quantum chemistry. In ICML, pp. 1263–1272. JMLR. org, 2017.
766
+
767
+ Jonathan Gordon, John Bronskill, Matthias Bauer, Sebastian Nowozin, and Richard E Turner. Metalearning probabilistic inference for prediction. In ICLR, 2019.
768
+
769
+ Erin Grant, Chelsea Finn, Sergey Levine, Trevor Darrell, and Thomas Griffiths. Recasting gradientbased meta-learning as hierarchical bayes. In ICLR, 2018.
770
+
771
+ Jiatao Gu, Yong Wang, Yun Chen, Kyunghyun Cho, and Victor OK Li. Meta-learning for low-resource
772
+ neural machine translation. In EMNLP, 2018.
773
+
774
+ Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In
775
+ _NeurIPS, pp. 1024–1034, 2017._
776
+
777
+ Ghassen Jerfel, Erin Grant, Thomas L Griffiths, and Katherine Heller. Reconciling meta-learning and
778
+ continual learning with online mixtures of tasks. NeurIPS, 2019.
779
+
780
+ Bingyi Kang, Zhuang Liu, Xin Wang, Fisher Yu, Jiashi Feng, and Trevor Darrell. Few-shot object
781
+ detection via feature reweighting. In ICCV, 2019.
782
+
783
+ Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks.
784
+ In ICLR, 2017.
785
+
786
+ Yoonho Lee and Seungjin Choi. Gradient-based meta-learning with learned layerwise metric and
787
+ subspace. In ICML, pp. 2933–2942, 2018.
788
+
789
+ Zhenguo Li, Fengwei Zhou, Fei Chen, and Hang Li. Meta-sgd: Learning to learn quickly for few
790
+ shot learning. arXiv preprint arXiv:1707.09835, 2017.
791
+
792
+ Zhaojiang Lin, Andrea Madotto, Chien-Sheng Wu, and Pascale Fung. Personalizing dialogue agents
793
+ via meta-learning. 2019.
794
+
795
+ Ming-Yu Liu, Xun Huang, Arun Mallya, Tero Karras, Timo Aila, Jaakko Lehtinen, and Jan Kautz.
796
+ Few-shot unsupervised image-to-image translation. arXiv preprint arXiv:1905.01723, 2019.
797
+
798
+ Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. A simple neural attentive meta-learner. In ICLR, 2018.
799
+
800
+ Alex Nichol and John Schulman. Reptile: a scalable metalearning algorithm. arXiv preprint
801
+ _arXiv:1803.02999, 2018._
802
+
803
+ Boris Oreshkin, Pau Rodríguez López, and Alexandre Lacoste. Tadam: Task dependent adaptive
+ metric for improved few-shot learning. In NeurIPS, pp. 721–731, 2018.
805
+
806
+ Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, and Aaron C. Courville. Film: Visual
807
+ reasoning with a general conditioning layer. In AAAI, 2018.
808
+
809
+
810
+ -----
811
+
812
+ Sachin Ravi and Hugo Larochelle. Optimization as a model for few-shot learning. ICLR, 2016.
813
+
814
+ Andrei A Rusu, Dushyant Rao, Jakub Sygnowski, Oriol Vinyals, Razvan Pascanu, Simon Osindero,
815
+ and Raia Hadsell. Meta-learning with latent embedding optimization. In ICLR, 2019.
816
+
817
+ Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. In
818
+ _NeurIPS, pp. 4077–4087, 2017._
819
+
820
+ Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua
+ Bengio. Graph attention networks. In ICLR, 2018.
822
+
823
+ Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Daan Wierstra, et al. Matching networks for one
824
+ shot learning. In NeurIPS, pp. 3630–3638, 2016.
825
+
826
+ Risto Vuorio, Shao-Hua Sun, Hexiang Hu, and Joseph J Lim. Toward multimodal model-agnostic
827
+ meta-learning. NeurIPS, 2019.
828
+
829
+ Xin Wang, Fisher Yu, Ruth Wang, Trevor Darrell, and Joseph E Gonzalez. Tafe-net: Task-aware
830
+ feature embeddings for low shot learning. In CVPR, pp. 1831–1840, 2019.
831
+
832
+ Flood Sung, Yongxin Yang, Li Zhang, Tao Xiang, Philip HS Torr, and Timothy M Hospedales.
833
+ Learning to compare: Relation network for few-shot learning. In CVPR, 2018.
834
+
835
+ Huaxiu Yao, Ying Wei, Junzhou Huang, and Zhenhui Li. Hierarchically structured meta-learning. In
836
+ _ICML, pp. 7045–7054, 2019._
837
+
838
+ Jaesik Yoon, Taesup Kim, Ousmane Dia, Sungwoong Kim, Yoshua Bengio, and Sungjin Ahn.
839
+ Bayesian model-agnostic meta-learning. In NeurIPS, pp. 7343–7353, 2018.
840
+
841
+ Sung Whan Yoon, Jun Seo, and Jaekyun Moon. Tapnet: Neural network augmented with task-adaptive
842
+ projection for few-shot learning. In ICML, 2019.
843
+
844
+
845
+ -----
846
+
847
+ A ALGORITHM IN META-TESTING PROCESS
848
+
849
+ **Algorithm 2 Meta-Testing Process of ARML**
850
+
851
+ **Require:** Training data D_t^tr of a new task T_t
+
+ 1: Construct the prototype-based relational graph R_t by computing the prototypes in equation 2 and
+ the weights in equation 4
+ 2: Compute the similarity between each prototype and each meta-knowledge vertex in equation 6 and
+ construct the super-graph S_t
+ 3: Apply the GNN on the super-graph S_t and get the updated prototype representation Ĉ_{R_t}
+ 4: Aggregate C_{R_t} in equation 8 and Ĉ_{R_t} in equation 9 and get the representations q_t, t_t
+ 5: Compute the task-specific initialization θ_{0t} in equation 10
+ 6: Update parameters θ_t = θ_{0t} − α∇_θ L(f_θ, D_t^tr)
866
+
867
+
868
+ B HYPERPARAMETERS SETTINGS
869
+
870
+ B.1 2D REGRESSION
871
+
872
+ In the 2D regression problem, we set the inner-loop stepsize (i.e., α) and the outer-loop stepsize (i.e., β) to
+ 0.001 and 0.001, respectively. The embedding function E is a single layer with 40 neurons. The
+ autoencoder aggregator is built from gated recurrent structures. We set the meta-batch size to
+ 25 and the number of inner-loop gradient steps to 5.
876
+
877
+ B.2 FEW-SHOT IMAGE CLASSIFICATION
878
+
879
+ In few-shot image classification, for both the Plain-Multi and Art-Multi datasets, we set the
+ inner stepsize (i.e., α) to 0.001 and the outer stepsize (i.e., β) to 0.01. For the embedding function E,
+ we employ two convolutional layers with 3 × 3 filters. The channel size of these two convolutional
+ layers is 32. After the convolutional layers, we use two fully connected layers with 384 and 64 neurons,
+ respectively. Similar to the hyperparameter settings in 2D regression, the autoencoder aggregator
+ is built from gated recurrent structures, i.e., AG^t, AG^t_dec, AG^q, AG^q_dec are all GRUs. The
+ meta-batch size is set to 4. For the inner loop, we use 5 gradient steps.
886
+
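+ As a concrete illustration, the description above can be turned into the following minimal sketch of the
+ embedding function E. The framework, the 84 × 84 input resolution, and the pooling/normalization layers
+ are not specified here, so those choices below are assumptions made only for illustration.
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class Embedding(nn.Module):
+     """Hypothetical sketch of E: two 3x3 conv layers (32 channels) + FC layers with 384 and 64 units."""
+     def __init__(self, in_channels=3, feat_size=21):  # feat_size assumes 84x84 inputs and two 2x2 poolings
+         super().__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
+             nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
+         )
+         self.fc = nn.Sequential(
+             nn.Flatten(), nn.Linear(32 * feat_size * feat_size, 384), nn.ReLU(), nn.Linear(384, 64),
+         )
+
+     def forward(self, x):
+         return self.fc(self.conv(x))
+
+ emb = Embedding()
+ print(emb(torch.randn(4, 3, 84, 84)).shape)  # -> torch.Size([4, 64])
+ ```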
887
+ B.3 DETAILED BASELINE SETTINGS
888
+
889
+ For the gradient-based baselines (i.e., MAML, MetaSGD, MT-Net, BMAML, MUMOMAML,
+ HSML), we use the same inner-loop and outer-loop stepsizes as our ARML. As for the
+ non-parametric meta-learning algorithms, i.e., TADAM and the Prototypical Network, we use the
892
+ same meta-training and meta-testing process as gradient-based models. Additionally, TADAM uses
893
+ the same embedding function E as ARML for fair comparison (i.e., similar expressive ability).
894
+
895
+ C ADDITIONAL DISCUSSION OF DATASETS
896
+
897
+ In this dataset, we use pencil and blur filters to change the task distribution. To investigate the effect
+ of the pencil and blur filters, we provide one example in Figure 4. We can observe that different filters
+ result in different data distributions. All filters used are provided by OpenCV[1].
900
+
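+ A minimal sketch of how such filters could be produced with OpenCV is shown below. The exact filter
+ parameters used to build Art-Multi are not given here, so the kernel size, the pencil-sketch parameters,
+ and the file names are assumptions for illustration only.
+
+ ```python
+ import cv2
+
+ image = cv2.imread("example.jpg")  # hypothetical input image
+
+ # Blur filter: a simple Gaussian blur (the kernel size is an assumption).
+ blurred = cv2.GaussianBlur(image, (15, 15), 0)
+
+ # Pencil filter: cv2.pencilSketch returns grayscale and color pencil-style renderings.
+ pencil_gray, pencil_color = cv2.pencilSketch(image, sigma_s=60, sigma_r=0.07, shade_factor=0.05)
+
+ cv2.imwrite("example_blur.jpg", blurred)
+ cv2.imwrite("example_pencil.jpg", pencil_color)
+ ```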
901
+ D RESULTS ON MINIIMAGENET
902
+
903
+ For MiniImagenet, since it does not have the characteristic of task heterogeneity, we show the results in
+ Table 5. In this table, we compare ARML on the MiniImagenet dataset with other gradient-based meta-learning
+ models (the first four baselines are globally shared models and the next four are task-specific models).
+ Similar to Finn et al. (2018), we apply the standard 4-block convolutional layers for each
907
+
908
+ 1https://opencv.org/
909
+
910
+
911
+ -----
912
+
913
+ (a) : Plain Image (b) : with blur filter (c) : with pencil filter
914
+
915
+ Figure 4: Effect of different filters.
916
+
917
+ baseline. For MT-Net, we use the results reported in Yao et al. (2019), which control for the same
+ expressive power of the model. The results indicate that our proposed ARML outperforms the original
+ MAML and achieves comparable performance with task-specific models (e.g., MT-Net, PLATIPUS,
+ HSML). Most task-specific models achieve similar performance on this standard benchmark due
+ to the homogeneity between tasks.
922
+
923
+ Table 5: Performance comparison on the 5-way, 1-shot MiniImagenet dataset.
924
+
925
+ |Algorithms|5-way 1-shot Accuracy|
+ |---|---|
+ |MAML (Finn et al., 2017)|48.70 ± 1.84%|
+ |LLAMA (Finn & Levine, 2018)|49.40 ± 1.83%|
+ |Reptile (Nichol & Schulman, 2018)|49.97 ± 0.32%|
+ |MetaSGD (Li et al., 2017)|50.47 ± 1.87%|
+ |MT-Net (Lee & Choi, 2018)|49.75 ± 1.83%|
+ |MUMOMAML (Vuorio et al., 2019)|49.86 ± 1.85%|
+ |HSML (Yao et al., 2019)|50.38 ± 1.85%|
+ |PLATIPUS (Finn et al., 2018)|50.13 ± 1.86%|
+ |ARML|50.42 ± 1.73%|
936
+
937
+
938
+ E ADDITIONAL RESULTS OF FEW-SHOT IMAGE CLASSIFICATION
939
+
940
+ E.1 FULL OVERALL RESULTS TABLE OF ART-MULTI DATASET
941
+
942
+ We provide the full results table of the Art-Multi dataset in Table 9. In this table, we can see that our proposed
+ ARML outperforms almost all baselines on every sub-dataset.
944
+
945
+ F FURTHER INVESTIGATION OF ABLATION STUDY
946
+
947
+ In this section, we first show the full evaluation results of the model ablation study on the Art-Multi dataset
+ in Table 6. Note that, for the tanh activation (ablation model V), the performance is similar to that of
+ the sigmoid activation. On some sub-datasets, the results are even better. We choose the sigmoid
+ activation for ARML because it achieves better performance than the tanh activation on more
+ sub-datasets. Then, for the Plain-Multi dataset, we show the results in Table 7. The conclusion of the ablation study
+ on the Plain-Multi dataset is similar to the conclusion drawn from the results on Art-Multi. The
+ improvement on these two datasets verifies the necessity of the joint framework in ARML.
954
+
955
+ G ADDITIONAL ANALYSIS OF META-KNOWLEDGE GRAPH
956
+
957
+ In this section, we add more interpretation analysis of meta-knowledge graph. First, we show the full
958
+ evaluation results of sensitivity analysis on Art-Multi dataset in Table 8.
959
+
960
+
961
+ -----
962
+
963
+ Table 6: Full evaluation results of model ablation study on Art-Multi dataset. B, T, A, F represent
964
+ bird, texture, aircraft, fungi, respectively. Plain means original image.
965
+
966
+ |Model|B Plain|B Blur|B Pencil|T Plain|T Blur|T Pencil|
+ |---|---|---|---|---|---|---|
+ |I. no prototype-based graph|72.08%|71.06%|66.83%|45.23%|39.97%|41.67%|
+ |II. no prototype|72.99%|70.92%|67.19%|45.17%|40.05%|41.04%|
+ |III. no meta-knowledge graph|70.79%|69.53%|64.87%|43.37%|39.86%|41.23%|
+ |IV. no reconstruction loss|70.82%|69.87%|65.32%|44.02%|40.18%|40.52%|
+ |V. tanh|72.70%|69.53%|66.85%|45.81%|40.79%|38.64%|
+ |VI. film|71.52%|68.70%|64.23%|43.83%|40.52%|39.49%|
+ |ARML|73.05%|71.31%|67.14%|45.32%|40.15%|41.98%|
+
+ |Model|A Plain|A Blur|A Pencil|F Plain|F Blur|F Pencil|
+ |---|---|---|---|---|---|---|
+ |I. no prototype-based graph|70.06%|68.02%|60.66%|55.81%|54.39%|50.01%|
+ |II. no prototype|71.10%|67.59%|61.07%|56.11%|54.82%|49.95%|
+ |III. no meta-knowledge graph|69.97%|68.03%|59.72%|55.84%|53.72%|48.91%|
+ |IV. no reconstruction loss|66.83%|65.73%|55.98%|54.62%|53.02%|48.01%|
+ |V. tanh|73.96%|69.70%|60.75%|56.87%|54.30%|49.82%|
+ |VI. film|69.13%|66.93%|55.59%|55.77%|53.72%|48.92%|
+ |ARML|71.89%|68.59%|61.41%|56.83%|54.87%|50.53%|
1003
+
1004
+
1005
+ Table 7: Results of Model Ablation (5-way, 5-shot results) on Plain-Multi dataset.
1006
+
1007
+ |Ablation Models|Bird|Texture|Aircraft|Fungi|
+ |---|---|---|---|---|
+ |I. no sample-level graph|71.96 ± 0.72%|48.79 ± 0.67%|74.02 ± 0.65%|56.83 ± 0.80%|
+ |II. no prototype|72.86 ± 0.74%|49.03 ± 0.69%|74.36 ± 0.65%|57.02 ± 0.81%|
+ |III. no meta-knowledge graph|71.23 ± 0.75%|47.96 ± 0.68%|73.71 ± 0.69%|55.97 ± 0.82%|
+ |IV. no reconstruction loss|70.99 ± 0.74%|48.03 ± 0.69%|69.86 ± 0.66%|55.78 ± 0.83%|
+ |V. tanh|73.45 ± 0.71%|49.23 ± 0.66%|74.39 ± 0.65%|57.38 ± 0.80%|
+ |VI. film|72.95 ± 0.73%|49.18 ± 0.69%|73.82 ± 0.68%|56.89 ± 0.80%|
+ |ARML|73.34 ± 0.70%|49.67 ± 0.67%|74.88 ± 0.64%|57.55 ± 0.82%|
1021
+
1022
+
1023
+ Then, we analyze the meta-knowledge graph on the Plain-Multi dataset by visualizing the learned
+ meta-knowledge graph (as shown in Figure 5). In this figure, we can see that
+ different sub-datasets activate different vertices. Specifically, V2, which is mainly activated by texture,
+ plays a significant role in aircraft and fungi. Thus, V2 connects with V3 and V1 in the
+ meta-knowledge graph, which are mainly activated by fungi and aircraft, respectively. In addition,
+ V0 is also activated by aircraft because of the similar contour between aircraft and birds. Furthermore,
+ in the meta-knowledge graph, V0 connects with V3, which reflects the similarity of environment between
+ bird images and fungi images.
1030
+
1031
+
1032
+ -----
1033
+
1034
+ [Figure 5 graphic: meta-knowledge graph with vertices V0–V3 linked to the Bird, Texture, Aircraft, and Fungi sub-datasets.]
1050
+
1051
+ Figure 5: Interpretation of meta-knowledge graph on Plain-Multi dataset. For each subdataset, one
1052
+ task is randomly selected from them. In the left figure, we show the similarity heatmap between
1053
+ prototypes (P1-P5) and meta-knowledge vertices (denoted as E1-E4), where deeper color means
1054
+ higher similarity. In the right part, we show the meta-knowledge graph, where a threshold is also set
1055
+ to filter low similarity links.
1056
+
1057
+ Table 8: Full evaluation results of performance v.s. # vertices of meta-knowledge graph on Art-Multi.
1058
+ B, T, A, F represent bird, texture, aircraft, fungi, respectively. Plain means original image.
1059
+
1060
+ |# of Vertices|B Plain|B Blur|B Pencil|T Plain|T Blur|T Pencil|
+ |---|---|---|---|---|---|---|
+ |4|72.29%|70.36%|67.88%|45.37%|41.05%|41.43%|
+ |8|73.05%|71.31%|67.14%|45.32%|40.15%|41.98%|
+ |12|73.45%|70.64%|67.41%|44.53%|41.41%|41.05%|
+ |16|72.68%|70.18%|68.34%|45.63%|41.43%|42.18%|
+ |20|73.41%|71.07%|68.64%|46.26%|41.80%|41.61%|
+
+ |# of Vertices|A Plain|A Blur|A Pencil|F Plain|F Blur|F Pencil|
+ |---|---|---|---|---|---|---|
+ |4|70.98%|67.36%|60.46%|56.07%|53.77%|50.08%|
+ |8|71.89%|68.59%|61.41%|56.83%|54.87%|50.53%|
+ |12|71.78%|67.26%|60.97%|56.87%|55.14%|50.86%|
+ |16|71.96%|68.55%|61.14%|56.76%|54.54%|49.41%|
+ |20|72.02%|68.29%|60.59%|55.95%|54.53%|50.13%|
1075
+
1076
+
1077
+ -----
1078
+
1079
+ Table 9: Full results on the Art-Multi dataset (Section E.1). B, T, A, F represent bird, texture, aircraft,
+ fungi, respectively; Plain means the original image.
+
+ 5-way, 1-shot:
+
+ |Algorithms|B Plain|B Blur|B Pencil|T Plain|T Blur|T Pencil|A Plain|A Blur|A Pencil|F Plain|F Blur|F Pencil|
+ |---|---|---|---|---|---|---|---|---|---|---|---|---|
+ |MAML|55.27%|52.62%|48.58%|30.57%|28.65%|28.39%|45.59%|42.24%|34.52%|39.37%|38.58%|35.38%|
+ |MetaSGD|55.23%|53.08%|48.18%|29.28%|28.70%|28.38%|51.24%|47.29%|35.98%|41.08%|40.38%|36.30%|
+ |MT-Net|56.99%|54.21%|50.25%|32.13%|29.63%|29.23%|43.64%|40.08%|33.73%|43.02%|42.64%|37.96%|
+ |MUMOMAML|57.73%|53.18%|50.96%|31.88%|29.72%|29.90%|49.95%|43.36%|39.61%|42.97%|40.08%|36.52%|
+ |HSML|58.15%|53.20%|51.09%|32.01%|30.21%|30.17%|49.98%|45.79%|40.87%|42.58%|41.29%|37.01%|
+ |ProtoNet|53.67%|50.98%|46.66%|31.37%|29.08%|28.48%|45.54%|43.94%|35.49%|37.71%|38.00%|34.36%|
+ |TADAM|54.76%|52.18%|48.85%|32.03%|29.90%|30.82%|50.42%|47.59%|40.17%|41.73%|40.09%|36.27%|
+ |ARML|59.67%|54.89%|52.97%|32.31%|30.77%|31.51%|51.99%|47.92%|41.93%|44.69%|42.13%|38.36%|
+
+ 5-way, 5-shot:
+
+ |Algorithms|B Plain|B Blur|B Pencil|T Plain|T Blur|T Pencil|A Plain|A Blur|A Pencil|F Plain|F Blur|F Pencil|
+ |---|---|---|---|---|---|---|---|---|---|---|---|---|
+ |MAML|71.51%|68.65%|63.93%|42.96%|39.59%|38.87%|64.68%|62.54%|49.20%|54.08%|52.02%|46.39%|
+ |MetaSGD|71.31%|68.73%|64.33%|41.89%|37.79%|37.91%|64.88%|63.36%|52.31%|53.18%|52.26%|46.43%|
+ |MT-Net|71.18%|69.29%|68.28%|43.23%|39.42%|39.20%|63.39%|58.29%|46.12%|54.01%|51.70%|47.02%|
+ |MUMOMAML|71.57%|70.50%|64.57%|44.57%|40.31%|40.07%|63.36%|61.55%|52.17%|54.89%|52.82%|47.79%|
+ |HSML|71.75%|69.31%|65.62%|44.68%|40.13%|41.33%|70.12%|67.63%|59.40%|55.97%|54.60%|49.40%|
+ |ProtoNet|70.42%|67.90%|61.82%|44.78%|38.43%|38.40%|65.84%|63.41%|54.08%|51.45%|50.56%|46.33%|
+ |TADAM|70.08%|69.05%|65.45%|44.93%|41.80%|40.18%|70.35%|68.56%|59.09%|56.04%|54.04%|47.85%|
+ |ARML|73.05%|71.31%|67.14%|45.32%|40.15%|41.98%|71.89%|68.59%|61.41%|56.83%|54.87%|50.53%|
1111
+
1112
+
1113
+
1188
+
1189
+ -----
1190
+
ai_scientist/fewshot_examples/2_carpe_diem.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "review": "{\n \"Summary\": \"This paper proposes Recency Bias, an adaptive mini batch selection method for training deep neural networks. To select informative minibatches for training, the proposed method maintains a fixed size sliding window of past model predictions for each data sample. At a given iteration, samples which have highly inconsistent predictions within the sliding window are added to the minibatch. The main contribution of this paper is the introduction of a sliding window to remember past model predictions, as an improvement over the SOTA approach: Active Bias, which maintains a growing window of model predictions. Empirical studies are performed to show the superiority of Recency Bias over two SOTA approaches. Results are shown on the task of (1) image classification from scratch and (2) image classification by fine-tuning pretrained networks.\",\n \"Strengths\": [\n \"The idea of using a sliding window over a growing window in active batch selection is interesting.\",\n \"Overall, the paper is well written. In particular, the Related Work section has a nice flow and puts the proposed method into context. Despite the method having limited novelty (sliding window instead of a growing window), the method has been well motivated by pointing out the limitations in SOTA methods.\",\n \"The results section is well structured. It's nice to see hyperparameter tuning results; and loss convergence graphs in various learning settings for each dataset.\"\n ],\n \"Weaknesses\": [\n \"The key concern about the paper is the lack of rigorous experimentation to study the usefulness of the proposed method. Despite the paper stating that there have been earlier work (Joseph et al, 2019 and Wang et al, 2019) that attempt mini-batch selection, the paper does not compare with them. This is limiting. Further, since the proposed method is not specific to the domain of images, evaluating it on tasks other than image classification, such as text classification for instance, would have helped validate its applicability across domains.\",\n \"Considering the limited results, a deeper analysis of the proposed method would have been nice. The idea of a sliding window over a growing window is a generic one, and there have been many efforts to theoretically analyze active learning over the last two decades. How does the proposed method fit in there? (For e.g., how does the expected model variance change in this setting?) Some form of theoretical/analytical reasoning behind the effectiveness of recency bias (which is missing) would provide greater insights to the community and facilitate further research in this direction.\",\n \"The claim of 20.5% reduction in test error mentioned in the abstract has not been clearly addressed and pointed out in the results section of the paper.\",\n \"The results would have been more complete if results were shown in a setting where just recency bias is used without the use of the selection pressure parameter. In other words, an ablation study on the effect of the selection pressure parameter would have been very useful.\",\n \"The intuition behind the method is described well, however, the proposed method would have been really solidified if it were analysed in the context of a simple machine learning problem (such as logistic regression). 
As an example, verifying if the chosen minibatch samples are actually close to the decision boundary of a model (even if the model is very simple) would have helped analyze the proposed method well.\"\n ],\n \"Originality\": 3,\n \"Quality\": 2,\n \"Clarity\": 4,\n \"Significance\": 2,\n \"Questions\": [\n \"How important is the warm-up phase to the proposed method? Considering the paper states that this is required to get good estimates of the quantization index of the samples, some ablation studies on reducing/increasing the warm-up phase and showing the results would have been useful to understand this.\",\n \"Fig 4: Why are there sharp dips periodically in all the graphs? What do these correspond to?\",\n \"The results are not conclusively in favor of the proposed method, and only is marginally better than the competitors. Why does online batch perform consistently than the proposed method? There is no discussion of these inferences from the results.\"\n ],\n \"Limitations\": [\n \"The primary concern is about the strength of the experimental results, which showed only a modest benefit on relatively simple datasets.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 2,\n \"Presentation\": 3,\n \"Contribution\": 2,\n \"Overall\": 4,\n \"Confidence\": 3,\n \"Decision\": \"Reject\"\n}"
3
+ }
ai_scientist/fewshot_examples/2_carpe_diem.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bae395d4e77efb99634b66a0c91616dab4d4af3d34e3e4eb745821e6ce7edcb1
3
+ size 858387
ai_scientist/fewshot_examples/2_carpe_diem.txt ADDED
@@ -0,0 +1,1035 @@
 
 
 
 
1
+ # CARPE DIEM, SEIZE THE SAMPLES UNCERTAIN “AT
2
+ ## THE MOMENT” FOR ADAPTIVE BATCH SELECTION
3
+
4
+ **Anonymous authors**
5
+ Paper under double-blind review
6
+
7
+ ABSTRACT
8
+
9
+ The performance of deep neural networks is significantly affected by how well
10
+ mini-batches are constructed. In this paper, we propose a novel adaptive batch
11
+ selection algorithm called Recency Bias that exploits the uncertain samples
12
+ predicted inconsistently in recent iterations. The historical label predictions of
13
+ each sample are used to evaluate its predictive uncertainty within a sliding window.
14
+ By taking advantage of this design, Recency Bias not only accelerates the training
15
+ step but also achieves a more accurate network. We demonstrate the superiority
16
+ of Recency Bias by extensive evaluation on two independent tasks. Compared with
17
+ existing batch selection methods, the results showed that Recency Bias reduced
18
+ the test error by up to 20.5% in a fixed wall-clock training time. At the same time,
19
+ it improved the training time by up to 59.3% to reach the same test error.
20
+
21
+ 1 INTRODUCTION
22
+
23
+ Stochastic gradient descent (SGD) for randomly selected mini-batch samples is commonly used to
24
+ train deep neural networks (DNNs). However, many recent studies have pointed out that the performance of DNNs is heavily dependent on how well the mini-batch samples are selected (Shrivastava
25
+ et al., 2016; Chang et al., 2017; Katharopoulos & Fleuret, 2018). In earlier approaches, a sample’s difficulty is employed to identify proper mini-batch samples, and these approaches achieve
26
+ a more accurate and robust network (Han et al., 2018) or expedite the training convergence of
27
+ SGD (Loshchilov & Hutter, 2016). However, the two opposing difficulty-based strategies, i.e., preferring easy samples (Kumar et al., 2010; Han et al., 2018) versus hard samples (Loshchilov & Hutter,
28
+ 2016; Shrivastava et al., 2016), work well in different situations. Thus, for practical reasons to cover
29
+ more diverse situations, recent approaches begin to exploit a sample’s uncertainty that indicates the
30
+ consistency of previous predictions (Chang et al., 2017; Song et al., 2019).
31
+
32
+ An important question here is how to evaluate the sample’s uncertainty based on its historical
33
+ predictions during the training process. Intuitively, because a series of historical predictions can
34
+ be seen as a series of data indexed in chronological order, the uncertainty can be measured based on
35
+ _two forms of handling time-series observations: (i) a growing window (Figure 1(a)) that consistently_
36
+ increases the size of a window to use all available observations and (ii) a sliding window (Figure 1(b))
37
+ that maintains a window of a fixed size on the most recent observations by deleting outdated ones.
38
+ While the state-of-the-art algorithm, Active Bias (Chang et al., 2017), adopts the growing window,
39
+ we propose to use the sliding window in this paper.
40
+
41
+ [Figure 1 graphic: (a) a growing window covering all available historical observations; (b) a sliding window covering only recent observations, with outdated ones discarded.]
54
+ (a) Growing Window. (b) Sliding Window.
55
+
56
+ Figure 1: Two forms of handling the time-series observations.
57
+
58
+ In more detail, Active Bias recognizes uncertain samples based on the inconsistency of the predictions
59
+ in the entire history of past SGD iterations. Then, it emphasizes such uncertain samples by choosing
60
+ them with high probability for the next mini-batch. However, according to our experiments presented
61
+
62
+
63
+ -----
64
+
65
+ [Figure 2 graphic: label-prediction histories of two horse images over previous training iterations. Both samples have inconsistent outdated predictions but consistent recent predictions (one always "Horse", i.e., too easy; one always "Deer", i.e., too hard), so Active Bias assigns them high uncertainty while Recency Bias assigns them low uncertainty.]
93
+ Figure 2: The difference in sample uncertainty estimated by Active Bias and Recency Bias.
94
+
95
+ in Section 5.2, such uncertain samples slowed down the convergence speed of training, though they
96
+ ultimately reduced the generalization error. This weakness is attributed to the inherent limitation of
97
+ the growing window, where older observations could be too outdated (Torgo, 2011). In other words,
98
+ the outdated predictions no longer represent a network’s current behavior. As illustrated in Figure
99
+ 2, when the label predictions of two samples were inconsistent for a long time, Active Bias invariably
100
+ regards them as highly uncertain, although their recent label predictions become consistent along
101
+ with the network’s training progress. This characteristic evidently entails the risk of emphasizing
102
+ uninformative samples that are too easy or too hard at the current moment, thereby slowing down
103
+ the convergence speed of training.
104
+
105
+ Therefore, we propose a simple but effective batch selection method, called Recency Bias, that takes
106
+ advantage of the sliding window to evaluate the uncertainty in fresher observations. As opposed to
107
+ _Active Bias, Recency Bias excludes the outdated predictions by managing a sliding window of a fixed_
108
+ size and picks up the samples predicted inconsistently within the sliding window. Thus, as shown
109
+ in Figure 2, the two samples uninformative at the moment are no longer selected by Recency Bias
110
+ simply because their recent predictions are consistent. Consequently, since informative samples are
111
+ effectively selected throughout the training process, this strategy not only accelerates the training
112
+ speed but also leads to a more accurate network.
113
+
114
+ To validate the superiority of Recency Bias, two popular convolutional neural networks (CNNs) were
115
+ trained for two independent tasks: image classification and fine tuning. We compared Recency Bias
116
+ with not only random batch selection (baseline) but also two state-of-the-art batch selection strategies.
117
+ Compared with three batch selection strategies, Recency Bias provided a relative reduction of test
118
+ error by 1.81%–20.5% in a fixed wall-clock training time. At the same time, it significantly reduced
119
+ the execution time by 24.6%–59.3% to reach the same test error.
120
+
121
+ 2 RELATED WORK
122
+
123
+ Let D = {(x_i, y_i) | 1 ≤ i ≤ N} be the entire training dataset composed of a sample x_i with its
+ true label y_i, where N is the total number of training samples. Then, a straightforward strategy to
+ construct a mini-batch M = {(x_i, y_i) | 1 ≤ i ≤ b} is to select b samples uniformly at random (i.e.,
+ P(x_i | D) = 1/N) from the training dataset D.
129
+
130
+ Because not all samples have an equal impact on training, many research efforts have been devoted
131
+ to develop advanced sampling schemes. Bengio et al. (2009) first took easy samples and then
132
+ gradually increased the difficulty of samples using heuristic rules. Kumar et al. (2010) determined the
133
+ easiness of the samples using their prediction errors. Recently, Tsvetkov et al. (2016) used Bayesian
134
+ optimization to learn an optimal curriculum for training dense, distributed word representations.
135
+ Sachan & Xing (2016) emphasized that the right curriculum must introduce a small number of the
136
+ samples dissimilar to those previously seen. Fan et al. (2017) proposed a neural data filter based on
137
+ reinforcement learning to select training samples adaptively. However, it is common for deep learning
138
+ to emphasize hard samples because of the plethora of easy ones (Katharopoulos & Fleuret, 2018).
139
+
140
+ Loshchilov & Hutter (2016) proposed a difficulty-based sampling scheme, called Online Batch,
141
+ that uses the rank of the loss computed from previous epochs. Online Batch sorts the previously
142
+ computed losses of samples in descending order and exponentially decays the sampling probability
143
+ of a sample according to its rank r. Then, the r-th ranked sample x(r) is selected with the probability
144
+ dropping by a factor of exp(log(s_e)/N), where s_e is the selection pressure parameter that affects
+ the probability gap between the most and the least important samples. When normalized to sum
+ to 1.0, the probability P(x_(r) | D; s_e) is defined by Eq. (1). It has been reported that Online Batch
148
+
149
+
150
+ -----
151
+
152
+ accelerates the convergence of training but deteriorates the generalization error because of the
153
+ overfitting to hard training samples (Loshchilov & Hutter, 2016).
154
+
155
+ P(x_(r) | D; s_e) = (1 / exp(log(s_e)/N)^r) / (Σ_{j=1}^{N} 1 / exp(log(s_e)/N)^j)    (1)
159
+
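+ The following minimal sketch illustrates Eq. (1) numerically; the variable names and the stand-in losses
+ are assumptions for illustration, not code from Loshchilov & Hutter (2016).
+
+ ```python
+ import numpy as np
+
+ def online_batch_probabilities(losses, s_e):
+     """Eq. (1): rank samples by loss (descending) and decay the sampling probability exponentially with rank."""
+     N = len(losses)
+     ranks = np.empty(N, dtype=int)
+     ranks[np.argsort(-losses)] = np.arange(1, N + 1)  # rank 1 = largest loss
+     weights = 1.0 / np.exp(np.log(s_e) / N) ** ranks
+     return weights / weights.sum()
+
+ losses = np.random.default_rng(0).exponential(size=1000)  # stand-in per-sample losses
+ p = online_batch_probabilities(losses, s_e=100.0)         # p[i]: probability of picking sample i
+ ```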
160
+ Most close to our work, Chang et al. (2017) devised an uncertainty-based sampling scheme, called
+ Active Bias, that chooses uncertain samples with high probability for the next batch. Active Bias
+ maintains the history H_i^{t−1} that stores all h(y_i | x_i) before the current iteration t (i.e., growing window),
+ where h(y_i | x_i) is the softmax probability of a given sample x_i for its true label y_i. Then, it measures
+ the uncertainty of the sample x_i by computing the variance over all h(y_i | x_i) in H_i^{t−1} and draws the
+ next mini-batch samples based on the normalized probability P(x_i | D, H_i^{t−1}; ϵ) in Eq. (2), where ϵ is
+ the smoothness constant to prevent the low variance samples from never being selected again. As
+ mentioned earlier in Section 1, Active Bias slows down the training process because the oldest part in
+ the history H_i^{t−1} no longer represents the current behavior of the network.
+
+ P(x_i | D, H_i^{t−1}; ϵ) = (ŝtd(H_i^{t−1}) + ϵ) / (Σ_{j=1}^{N} (ŝtd(H_j^{t−1}) + ϵ)),
+ ŝtd(H_i^{t−1}) = sqrt( var(h(y_i | x_i)) + var(h(y_i | x_i))^2 / |H_i^{t−1}| )    (2)
177
+
178
+ For the completeness of the survey, we include the recent studies on submodular batch selection.
179
+ Joseph et al. (2019) and Wang et al. (2019) designed their own submodular objectives that cover
180
+ diverse aspects, such as sample redundancy and sample representativeness, for more effective
181
+ batch selection. Differently from their work, we explore the issue of truly uncertain samples in
182
+ an orthogonal perspective. Our uncertainty measure can be easily injected into their submodular
183
+ optimization framework as a measure of sample informativeness.
184
+
185
+ In Section 5, we will confirm that Recency Bias outperforms Online Batch and Active Bias, which are
186
+ regarded as two state-of-the-art adaptive batch selection methods for deep learning.
187
+
188
+ 3 _Recency Bias COMPONENTS_
189
+
190
+ 3.1 CRITERION OF AN UNCERTAIN SAMPLE
191
+
192
+ The main challenge of Recency Bias is to identify the samples whose recent label predictions are
193
+ highly inconsistent, which are neither too easy nor too hard at the moment. Thus, we adopt the
194
+ _predictive uncertainty (Song et al., 2019) in Definition 3.1 that uses the information entropy (Chandler,_
195
+ 1987) to measure the inconsistency of recent label predictions. Here, the sample with high predictive
196
+ uncertainty is regarded as uncertain and selected with high probability for the next mini-batch.
197
+ **Definition 3.1. (Predictive Uncertainty) Let ŷ_t = Φ(x_i, θ_t) be the predicted label of a sample x_i at**
+ time t and H_{x_i}(q) = {ŷ_{t_1}, ŷ_{t_2}, . . ., ŷ_{t_q}} be the label history of the sample x_i that stores the predicted
+ labels at the previous q times, where Φ is a neural network. The label history H_{x_i}(q) corresponds
+ to the sliding window of size q to compute the uncertainty of the sample x_i. Next, p(y | x_i; q) is
+ formulated such that it provides the probability of the label y ∈ {1, 2, ..., k} estimated as the label of
+ the sample x_i based on H_{x_i}(q) as in Eq. (3), where [·] is the Iverson bracket[1].
+
+ p(y | x_i; q) = ( Σ_{ŷ ∈ H_{x_i}(q)} [ŷ = y] ) / |H_{x_i}(q)|    (3)
+
+ Then, to quantify the uncertainty of the sample x_i, the predictive uncertainty F(x_i; q) is defined by
+ Eq. (4), where δ is the standardization term to normalize the value to [0, 1].
+
+ F(x_i; q) = −(1/δ) Σ_{j=1}^{k} p(j | x_i; q) log p(j | x_i; q),  δ = − log(1/k)    (4) □
230
+
231
+ 1The Iverson bracket [p] returns 1 if p is true; 0 otherwise.
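+ A minimal numerical sketch of Definition 3.1 follows; the window contents and class count below are
+ invented for illustration and are not taken from the paper.
+
+ ```python
+ import numpy as np
+ from collections import deque
+
+ def predictive_uncertainty(label_history, num_classes):
+     """F(x; q): entropy of the labels predicted within the sliding window, normalized by -log(1/k)."""
+     counts = np.bincount(np.asarray(label_history), minlength=num_classes)
+     p = counts / len(label_history)
+     nonzero = p[p > 0]
+     entropy = -np.sum(nonzero * np.log(nonzero))
+     return entropy / (-np.log(1.0 / num_classes))
+
+ # Example: a sliding window (q = 10) of recent predictions for one sample, with k = 10 classes.
+ history = deque([3, 7, 3, 3, 7, 3, 7, 7, 3, 3], maxlen=10)
+ print(predictive_uncertainty(history, num_classes=10))  # ~0.29: moderately uncertain
+ ```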
232
+
233
+
234
+ -----
235
+
236
+ 3.2 SAMPLING PROBABILITY FOR MINI-BATCH CONSTRUCTION
237
+
238
+ To construct next mini-batch samples, we assign the sampling probability according to the predictive
239
+ uncertainty in Definition 3.1. Motivated by Loshchilov & Hutter (2016), the sampling probability
240
+ of a given sample xi is exponentially decayed with its predictive uncertainty F (xi; q). In detail,
241
+ we adopt the quantization method (Chen & Wornell, 2001) and use the quantization index to decay
242
+ the sampling probability. The index is obtained by the simple quantizer Q in Eq. (5), where ∆ is
243
+ the quantization step size. Compared with the rank-based index (Loshchilov & Hutter, 2016), the
244
+ quantization index is known to well reflect the difference in actual values (Widrow et al., 1996).
245
+
246
+ Q(F(x_i; q)) = ⌈(1 − F(x_i; q)) / ∆⌉,  0 ≤ F(x_i; q) ≤ 1    (5)
248
+
249
+ In Eq. (5), we set ∆ to be 1/N such that the index is bounded to N (the total number of samples).
+ Then, the sampling probability P(x_i | D; s_e) is defined as in Eq. (6). The higher the predictive
+ uncertainty, the smaller the quantization index. Therefore, a higher sampling probability is assigned
+ to uncertain samples in Eq. (6).
+
+ P(x_i | D; s_e) = (1 / exp(log(s_e)/N)^{Q(F(x_i; q))}) / (Σ_{j=1}^{N} 1 / exp(log(s_e)/N)^{Q(F(x_j; q))})    (6)
259
+
260
+ Meanwhile, it is known that using only some part of the training data exacerbates the overfitting problem
261
+ at a late stage of training (Loshchilov & Hutter, 2016; Zhou & Bilmes, 2018). Thus, to alleviate
262
+ the problem, we include more training samples as the training progresses by exponentially decaying
263
+ the selection pressure se as in Eq. (7). At each epoch e from e0 to eend, the selection pressure
264
+ _se exponentially decreases from se0 to 1. Because this technique gradually reduces the sampling_
265
+ probability gap between the most and the least uncertain samples, more diverse samples are selected
266
+ for the next mini-batch at a later epoch. When the selection pressure se becomes 1, the mini-batch
267
+ samples are randomly chosen from the entire dataset.
268
+
269
+ s_e = s_{e_0} · ( exp( log(1/s_{e_0}) / (e_end − e_0) ) )^{e − e_0}    (7)
272
+
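+ The sketch below strings Eqs. (5)-(7) together; the uncertainty values, epoch schedule, and batch size
+ are stand-ins chosen only for illustration.
+
+ ```python
+ import numpy as np
+
+ def selection_pressure(e, s_e0=100.0, e0=0, e_end=100):
+     """Eq. (7): decay s_e exponentially from s_e0 (at epoch e0) to 1 (at epoch e_end)."""
+     return s_e0 * np.exp(np.log(1.0 / s_e0) / (e_end - e0)) ** (e - e0)
+
+ def sampling_probabilities(uncertainties, s_e):
+     """Eqs. (5)-(6): quantize F(x; q) with step 1/N, then decay the probability exponentially with the index."""
+     N = len(uncertainties)
+     q_idx = np.ceil((1.0 - uncertainties) * N)         # Eq. (5) with delta = 1/N
+     weights = 1.0 / np.exp(np.log(s_e) / N) ** q_idx   # smaller index (higher uncertainty) -> larger weight
+     return weights / weights.sum()
+
+ rng = np.random.default_rng(0)
+ F = rng.uniform(size=1000)                             # stand-in predictive uncertainties
+ p = sampling_probabilities(F, selection_pressure(e=20))
+ batch = rng.choice(len(F), size=128, replace=False, p=p)  # indices of the next adaptive mini-batch
+ ```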
273
+ 4 _Recency Bias ALGORITHM_
274
+
275
+ **Algorithm 1 Recency Bias Algorithm**
276
+
277
+ INPUT: D: data, epochs, b: batch size, q: window size, s_{e_0}: initial selection pressure, γ: warm-up
+ OUTPUT: θ_t: model parameter
+
+ 1: t ← 1;
+ 2: θ_t ← Initialize the model parameter;
+ 3: for i = 1 to epochs do
+ 4:   /* Sampling Probability Derivation */
+ 5:   if i > γ then
+ 6:     s_e ← Decay_Selection_Pressure(s_{e_0}, i); /* Decaying s_e by Eq. (7) */
+ 7:     for m = 1 to N do /* Updating the index and the sampling probability in a batch */
+ 8:       q_dict[x_m] = Q(F(x_m; q)); /* By Eq. (5) */
+ 9:     p_table ← Compute_Prob(q_dict, s_e); /* By Eq. (6) */
+ 10:  /* Network Training */
+ 11:  for j = 1 to N/b do /* Mini-batch */
+ 12:    if i ≤ γ then /* Warm-up */
+ 13:      {(x_1, y_1), . . ., (x_b, y_b)} ← Randomly select next mini-batch samples;
+ 14:    else /* Adaptive batch selection */
+ 15:      {(x_1, y_1), . . ., (x_b, y_b)} ← Select next mini-batch samples based on p_table;
+ 16:    losses, labels ← Inference_Step({(x_1, y_1), . . ., (x_b, y_b)}, θ_t); /* Forward */
+ 17:    θ_{t+1} ← SGD_Step(losses, θ_t); /* Backward */
+ 18:    Update_Label_History(labels); /* By Definition 3.1 */
+ 19:    t ← t + 1;
+ 20: return θ_t;
321
+
322
+ Algorithm 1 describes the overall procedure of Recency Bias. The algorithm requires a warm-up
323
+ period of γ epochs because the quantization index for each sample is not confirmed yet. During
324
+ the warm-up period, which should be at least q epochs (γ ≥ _q) to obtain the label history of size_
325
+
326
+
327
+ -----
328
+
329
+ _q, randomly selected mini-batch samples are used for the network update (Lines 12–13). After the_
330
+ warm-up period, the algorithm decays the selection pressure se and updates not only the quantization
331
+ index but also the sampling probability in a batch at the beginning of each epoch (Lines 4–9).
332
+ Subsequently, the uncertain samples are selected for the next mini-batch according to the updated
333
+ sampling probability (Line 14–15), and then the label history is updated along with the network
334
+ update (Lines 16–19).
335
+
336
+ Overall, the key technical novelty of Recency Bias is to incorporate the notion of a sliding win_dow (Line 8) rather than a growing window into adaptive batch selection, thereby improving both_
337
+ training speed and generalization error.
338
+
339
+ **Time Complexity: The main “additional” cost of Recency Bias is the derivation of the sampling**
340
+ probability for each sample (Lines 4–9). Because only simple mathematical operations are needed
341
+ per sample, its time complexity is linear to the number of samples (i.e., O(N )), which is negligible
342
+ compared with that of the forward and backward steps of a complex network (Lines 16–17). Therefore,
343
+ we contend that Recency Bias does not add the complexity of an underlying optimization algorithm.
344
+
345
+ 5 EVALUATION
346
+
347
+ We empirically show the improvement of Recency Bias over not only Random Batch (baseline) but also
348
+ _Online Batch (Loshchilov & Hutter, 2016) and Active Bias (Chang et al., 2017), which are two state-_
349
+ of-the-art adaptive batch selections. In particular, we elaborate on the effect of the sliding window
350
+ approach (Recency Bias) compared with the growing window approach (Active Bias). Random Batch
351
+ selects next mini-batch samples uniformly at random from the entire dataset. Online Batch selects hard
352
+ samples based on the rank of the loss computed from previous epochs. Active Bias selects uncertain
353
+ samples with high variance of true label probabilities in the growing window. All the algorithms
354
+ were implemented using TensorFlow 1.8.0 and executed using a single NVIDIA Titan Volta GPU.
355
+ [For reproducibility, we provide the source code at https://github.com/anonymized.](https://github.com/anonymized)
356
+
357
+ Image classification and fine-tuning tasks were performed to validate the superiority of Recency Bias.
358
+ Because fine-tuning is used to quickly adapt to a new dataset, it is suitable to reap the benefit of fast
359
+ training speed. In support of reliable evaluation, we repeated every task thrice and reported the average
360
+ and standard error of the best test errors. The best test error in a given time has been widely used for
361
+ the studies on fast and accurate training (Katharopoulos & Fleuret, 2018; Loshchilov & Hutter, 2016).
362
+
363
+ 5.1 ANALYSIS ON SELECTED MINI-BATCH SAMPLES
364
+
365
+ For an in-depth analysis on selected samples, we plot the loss distribution of mini-batch samples
366
+ selected from CIFAR-10 by four different strategies in Figure 3. (i) The distribution of Online Batch
367
+ is the most skewed toward high loss by the design principle of selecting hard samples. (ii) Active Bias
368
+ emphasizes moderately hard samples at an early training stage in considering that its loss distribution
369
+ lies between those of Random Batch and Online Batch. However, owing to the outdated predictions
370
+ caused by the growing window, the proportion of easy samples with low loss increases at a late
371
+ training stage. These easy samples, which are misclassified as uncertain at that stage, tend to make the
372
+ convergence of training slow down. (iii) In contrast to Active Bias, by virtue of the sliding window,
373
+ the distribution of Recency Bias lies between those of Random Batch and Online Batch regardless of
374
+ the training stage. Consequently, Recency Bias continues to highlight the moderately hard samples,
375
+ which are likely to be informative, during the training process.
376
+
377
+ [Figure 3 graphic: loss histograms (log-scale loss on the x-axis) for Random Batch, Online Batch, Active Bias (growing window), and Recency Bias (sliding window).]
385
+
386
+
387
+ (a) Early Stage (30%). (b) Late Stage (70%).
388
+
389
+ Figure 3: The loss distribution of mini-batch samples selected by four batch selection strategies: (a)
390
+ and (b) show the loss distribution at the 30% and 70% of total training epochs, respectively.
391
+
392
+
393
+ -----
394
+
395
+ 5.2 TASK I: IMAGE CLASSIFICATION
396
+
397
+ **Experiment Setting: We trained DenseNet (L=40, k=12) and ResNet (L=50) with a momentum**
398
+ optimizer and an SGD optimizer on three benchmark datasets: MNIST (10 classes)[2], classification
399
+ of handwritten digits (LeCun, 1998), and CIFAR-10 (10 classes)[3] and CIFAR-100 (100 classes)[3],
400
+ classification of a subset of 80 million categorical images (Krizhevsky et al., 2014). Specifically, we
401
+ used data augmentation, batch normalization, a momentum of 0.9, and a batch size of 128. As for the
402
+ algorithm parameters, we fixed the window size q = 10 and the initial selection pressure se0 = 100,[4]
403
+ which were the best values found by the grid search (see Appendix A for details). The warm-up
404
+ epoch γ was set to be 15. To reduce the performance variance caused by randomly initialized model
405
+ parameters, all parameters were shared by all algorithms during the warm-up period. Regarding
406
+ the training schedule, we trained the network for 40, 000 iterations and used an initial learning rate
407
+ of 0.1, which was divided by 10 at 50% and 75% of the total number of training iterations.
408
+
409
+ **Results: Figure 4 shows the convergence curves of training loss and test error for four batch selection**
410
+ strategies using DenseNet and a momentum optimizer. In order to highlight the improvement of
411
+ _Recency Bias over the baseline (Random Batch), their lines are dark colored. The best test errors in_
412
+ Figures 4(b), 4(d), and 4(f) are summarized on the left side of Table 1.
413
+
414
+ In general, Recency Bias achieved the most accurate network while accelerating the training process
415
+ on all datasets. The training loss of Recency Bias converged faster (Figures 4(a), 4(c), and 4(e))
416
+ without the increase in the generalization error, thereby achieving the lower test error (Figures 4(b),
417
+ 4(d), and 4(f)). In contrast, the test error of Online Batch was not the best even if its training loss
418
+ converged the fastest among all strategies. As the training difficulty increased from CIFAR-10 to
419
+ CIFAR-100, the test error of Online Batch became even worse than that of Random Batch. That
420
+ is, emphasizing hard samples accelerated the training step but made the network overfit to hard
421
+ samples. Meanwhile, Active Bias tended to make the network generalize better on test data.
422
+ In CIFAR-10, despite its highest training loss, the test error of Active Bias was better than that of
423
+ _Random Batch. However, Active Bias slowed down the training process because of the limitation_
424
+ of growing windows, as discussed in Section 5.1. We note that, although both Recency Bias and
425
+ _Active Bias exploited uncertain samples, only Recency Bias based on sliding windows succeeded_
426
+ to not only speed up the training process but also reduce the generalization error.
427
+
428
+ The results of the best test error for ResNet or an SGD optimizer are summarized in Tables 1 and
429
+ 2 (see Appendix B for more details). Regardless of a neural network and an optimizer, Recency
430
+ _Bias achieved the lowest test error except in MNIST with an SGD optimizer. The improvement of_
431
+ _Recency Bias over the others was higher with an SGD optimizer than with a momentum optimizer._
432
+
433
+ Table 1: The best test errors (%) of four batch selection strategies using DenseNet.
434
+
435
+ |Optimizer|Momentum in Figure 4| | |SGD in Figure 7 (Appendix B.1)| | |
436
+ |---|---|---|---|---|---|---|
437
+ |Method|MNIST|CIFAR-10|CIFAR-100|MNIST|CIFAR-10|CIFAR-100|
438
+ |Random Batch|0.527 ± 0.03|7.33 ± 0.09|28.0 ± 0.16|1.23 ± 0.03|14.9 ± 0.09|40.2 ± 0.06|
439
+ |Online Batch|0.514 ± 0.01|7.00 ± 0.10|28.4 ± 0.25|0.765 ± 0.02|13.5 ± 0.02|40.7 ± 0.12|
440
+ |Active Bias|0.616 ± 0.03|7.07 ± 0.04|27.9 ± 0.11|0.679 ± 0.02|14.2 ± 0.25|42.9 ± 0.05|
441
+ |Recency Bias|0.490 ± 0.02|6.60 ± 0.02|27.1 ± 0.19|0.986 ± 0.06|13.2 ± 0.11|38.7 ± 0.11|
442
+
443
+
444
+
445
+ Table 2: The best test errors (%) of four batch selection strategies using ResNet.
446
+
447
+ |Optimizer|Momentum in Figure 8 (Appendix B.2)| | |SGD in Figure 9 (Appendix B.3)| | |
448
+ |---|---|---|---|---|---|---|
449
+ |Method|MNIST|CIFAR-10|CIFAR-100|MNIST|CIFAR-10|CIFAR-100|
450
+ |Random Batch|0.636 ± 0.04|10.2 ± 0.12|33.2 ± 0.07|1.16 ± 0.03|12.7 ± 0.09|40.1 ± 0.16|
451
+ |Online Batch|0.666 ± 0.05|10.1 ± 0.05|33.4 ± 0.01|0.890 ± 0.03|12.2 ± 0.08|40.7 ± 0.09|
452
+ |Active Bias|0.613 ± 0.04|10.6 ± 0.08|34.2 ± 0.07|0.804 ± 0.01|13.5 ± 0.07|45.6 ± 0.07|
453
+ |Recency Bias|0.607 ± 0.01|9.79 ± 0.04|32.4 ± 0.04|0.972 ± 0.03|11.6 ± 0.09|38.9 ± 0.14|
454
+
455
+
456
+
457
+ [2http://yann.lecun.com/exdb/mnist](http://yann.lecun.com/exdb/mnist)
458
+ [3https://www.cs.toronto.edu/~kriz/cifar.html](https://www.cs.toronto.edu/~kriz/cifar.html)
459
+ 4Online Batch also used the same decaying selection pressure value.
460
+
461
+
462
+ -----
463
+
464
+ [Figure 4 graphics: training loss and test error versus wall-clock time (s) for Random Batch, Online Batch, Active Bias, and Recency Bias.]
+
+ (a) MNIST Training Loss. (b) MNIST Test Error. (c) CIFAR-10 Training Loss. (d) CIFAR-10 Test Error. (e) CIFAR-100 Training Loss. (f) CIFAR-100 Test Error.
529
+
530
+ Figure 4: Convergence curves of four batch selection strategies using DenseNet with momentum.
531
+
532
+
533
+ 5.3 TASK II: FINE-TUNING
534
+
535
+ **Experiment Setting: We prepared DenseNet (L=121, k=32) previously trained on ImageNet (Deng**
536
+ et al., 2009) and then fine-tuned the network on two benchmark datasets: MIT-67 (67 classes)[5],
537
+ classification of indoor scenes (Quattoni & Torralba, 2009), and Food-100 (100 classes)[6], classification of popular foods in Japan (Kawano & Yanai, 2014). After replacing the last classification
538
+ layer, the network was trained end-to-end for 50 epochs with a batch size 32 and a constant learning
539
+ rate 2 × 10[−][4]. Data augmentation was not applied here. The other configurations were the same
540
+ as those in Section 5.2.
541
+
542
+ **Results on Test Error: Figure 5 shows the convergence curves of training loss and test error for**
543
+ the fine-tuning task on MIT-67 and Food-100. Overall, all convergence curves showed similar trends
544
+ to those of the classification task in Figure 4. Only Recency Bias converged faster than Random
545
+ _Batch in both training loss and test error. Online Batch converged the fastest in training loss, but_
546
+ its test error was rather higher than Random Batch owing to the overfitting. Active Bias converged the
547
+
548
+
549
+ [5http://web.mit.edu/torralba/www/indoor.html](http://web.mit.edu/torralba/www/indoor.html)
550
+ [6http://foodcam.mobi/dataset100.html](http://foodcam.mobi/dataset100.html)
551
+
552
+
553
+ -----
554
+
555
+ [Figure 5 graphics: training loss and test error versus wall-clock time (s) for Random Batch, Online Batch, Active Bias, and Recency Bias, with annotated training-time reductions of 24.6% (MIT-67) and 26.1% (Food-100) for Recency Bias over Random Batch.]
+
+ (a) MIT-67 Training Loss. (b) MIT-67 Test Error. (c) Food-100 Training Loss. (d) Food-100 Test Error.
622
+
623
+ Figure 5: Convergence curves for fine-tuning on two benchmark datasets.
624
+
625
+
626
+ Table 3: Recency Bias’s reduction in training time over other batch selection strategies.
627
+
628
+
629
+ |Method|MIT-67|FOOD-100|
630
+ |---|---|---|
631
+ |Random Batch|(5,218 − 3,936)/5,218 × 100 = 24.6%|(7,263 − 5,365)/7,263 × 100 = 26.1%|
+ |Online Batch|(6,079 − 3,823)/6,079 × 100 = 37.1%|(8,333 − 3,685)/8,333 × 100 = 55.8%|
+ |Active Bias|(5,738 − 3,032)/5,738 × 100 = 47.2%|(7,933 − 3,227)/7,933 × 100 = 59.3%|
634
+
635
+
636
+ slowest in both training loss and test error. Quantitatively, compared with Random Batch, Recency
637
+ _Bias reduced the test error by 2.88% and 1.81% in MIT-67 and Food-100, respectively._
638
+
639
+ **Results on Training Time: Moreover, to assess the performance gain in training time, we computed**
640
+ the reduction in the training time taken to reach the same error. For example, in Figure 5(b), the
641
+ best test error of 28.8% achieved in 5, 218 seconds by Random Batch could be achieved only in
642
+ 3, 936 seconds by Recency Bias; thus, Recency Bias improved the training time by 24.6%. Table
643
+ 3 summarizes the reduction in the training time of Recency Bias over three other batch selection
644
+ strategies. Notably, Recency Bias improved the training time by 24.6%–47.2% and 26.1%–59.3% in
645
+ fine-tuning MIT-67 and FOOD-100 datasets, respectively.
646
+
647
+ 6 CONCLUSION
648
+
649
+
650
+ In this paper, we presented a novel adaptive batch selection algorithm called Recency Bias that
651
+ emphasizes predictively uncertain samples for accelerating the training of neural networks. Toward
652
+ this goal, the predictive uncertainty of each sample is evaluated using its recent label predictions
653
+ managed by a sliding window of a fixed size. Then, uncertain samples at the moment are selected with
654
+ high probability for the next mini-batch. We conducted extensive experiments on both classification
655
+ and fine-tuning tasks. The results showed that Recency Bias is effective in reducing the training
656
+ time as well as the best test error. It is worth noting that using all historical observations to
657
+ estimate the uncertainty has the side effect of slowing down the training process. Overall, combining
658
+ uncertainty-based selection with a sliding window greatly improves the power of adaptive batch selection.
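+
+ A minimal sketch of the selection loop summarized above (illustrative only, not the paper's code: the
+ uncertainty measure and the probability assignment below are assumptions standing in for the exact
+ formulas defined in the method section; num_samples is the training-set size):
+
+ ```python
+ import numpy as np
+ from collections import Counter, deque
+
+ q = 10                 # sliding window size (Appendix A)
+ se = 100.0             # selection pressure (Appendix A)
+ num_samples = 50_000   # e.g., CIFAR-10 training set size
+
+ # recent predicted labels per sample, kept in a fixed-size window
+ histories = [deque(maxlen=q) for _ in range(num_samples)]
+
+ def uncertainty(hist):
+     # illustrative: disagreement among the recent predictions (1 = maximally uncertain)
+     if not hist:
+         return 1.0
+     most_common_count = Counter(hist).most_common(1)[0][1]
+     return 1.0 - most_common_count / len(hist)
+
+ def sampling_probs(histories, se):
+     # illustrative: rank samples by uncertainty and decay the probability geometrically
+     # so that the most uncertain sample is se times more likely than the least uncertain
+     u = np.array([uncertainty(h) for h in histories])
+     ranks = np.argsort(np.argsort(-u))             # 0 = most uncertain
+     n = len(u)
+     p = np.exp(-np.log(se) * ranks / max(n - 1, 1))
+     return p / p.sum()
+
+ # each step: draw the next mini-batch indices with these probabilities, then append the
+ # model's newly predicted labels to the corresponding histories
+ ```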
659
+
660
+
661
+ -----
662
+
663
+ REFERENCES
664
+
665
+ Yoshua Bengio, Jérôme Louradour, Ronan Collobert, and Jason Weston. Curriculum learning. In
666
+ _ICML, pp. 41–48, 2009._
667
+
668
+ David Chandler. Introduction to modern statistical mechanics. Oxford University Press, 1987.
669
+
670
+ Haw-Shiuan Chang, Erik Learned-Miller, and Andrew McCallum. Active Bias: Training more
671
+ accurate neural networks by emphasizing high variance samples. In NeurIPS, pp. 1002–1012,
672
+ 2017.
673
+
674
+ Brian Chen and Gregory W Wornell. Quantization index modulation: A class of provably good
675
+ methods for digital watermarking and information embedding. IEEE Trans. on Information Theory,
676
+ 47(4):1423–1443, 2001.
677
+
678
+ Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale
679
+ hierarchical image database. In CVPR, pp. 248–255, 2009.
680
+
681
+ Yang Fan, Fei Tian, Tao Qin, and Tie-Yan Liu. Neural data filter for bootstrapping stochastic gradient
682
+ descent. In ICLR, 2017.
683
+
684
+ Bo Han, Quanming Yao, Xingrui Yu, Gang Niu, Miao Xu, Weihua Hu, Ivor Tsang, and Masashi
685
+ Sugiyama. Co-teaching: Robust training of deep neural networks with extremely noisy labels. In
686
+ _NeurIPS, pp. 8527–8537, 2018._
687
+
688
+ KJ Joseph, Krishnakant Singh, Vineeth N Balasubramanian, et al. Submodular batch selection for
689
+ training deep neural networks. In IJCAI, pp. 2677–3683, 2019.
690
+
691
+ Angelos Katharopoulos and François Fleuret. Not all samples are created equal: Deep learning with
692
+ importance sampling. In ICML, pp. 2525–2534, 2018.
693
+
694
+ Y. Kawano and K. Yanai. Food image recognition with deep convolutional features. In UbiComp,
695
+ 2014.
696
+
697
+ Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. CIFAR-10 and CIFAR-100 datasets, 2014.
698
+ [https://www.cs.toronto.edu/~kriz/cifar.html.](https://www.cs.toronto.edu/~kriz/cifar.html)
699
+
700
+ M Pawan Kumar, Benjamin Packer, and Daphne Koller. Self-paced learning for latent variable
701
+ models. In NeurIPS, pp. 1189–1197, 2010.
702
+
703
+ [Yann LeCun. The MNIST database of handwritten digits, 1998. http://yann.lecun.com/](http://yann.lecun.com/exdb/mnist)
704
+ [exdb/mnist.](http://yann.lecun.com/exdb/mnist)
705
+
706
+ Ilya Loshchilov and Frank Hutter. Online batch selection for faster training of neural networks. In
707
+ _ICLR, 2016._
708
+
709
+ Ariadna Quattoni and Antonio Torralba. Recognizing indoor scenes. In CVPR, pp. 413–420, 2009.
710
+
711
+ Mrinmaya Sachan and Eric Xing. Easy questions first? A case study on curriculum learning for
712
+ question answering. In ACL, pp. 453–463, 2016.
713
+
714
+ Abhinav Shrivastava, Abhinav Gupta, and Ross Girshick. Training region-based object detectors
715
+ with online hard example mining. In CVPR, pp. 761–769, 2016.
716
+
717
+ Hwanjun Song, Minseok Kim, and Jae-Gil Lee. SELFIE: Refurbishing unclean samples for robust
718
+ deep learning. In ICML, pp. 5907–5915, 2019.
719
+
720
+ Luis Torgo. Data mining with R: learning with case studies. Chapman and Hall/CRC, 2011.
721
+
722
+ Yulia Tsvetkov, Manaal Faruqui, Wang Ling, Brian MacWhinney, and Chris Dyer. Learning the
723
+ curriculum with bayesian optimization for task-specific word representation learning. In ACL, pp.
724
+ 130–139, 2016.
725
+
726
+ Shengjie Wang, Wenruo Bai, Chandrashekhar Lavania, and Jeff Bilmes. Fixing mini-batch sequences
727
+ with hierarchical robust partitioning. In AISTATS, pp. 3352–3361, 2019.
728
+
729
+
730
+ -----
731
+
732
+ Bernard Widrow, Istvan Kollar, and Ming-Chang Liu. Statistical theory of quantization. IEEE
733
+ _Transactions on instrumentation and measurement, 45(2):353–361, 1996._
734
+
735
+ Tianyi Zhou and Jeff Bilmes. Minimax curriculum learning: Machine teaching with desirable
736
+ difficulties and scheduled diversity. In ICLR, 2018.
737
+
738
+
739
+ -----
740
+
741
+ A HYPERPARAMETER SELECTION
742
+
743
+ _Recency Bias_ takes two hyperparameters: (i) the initial selection pressure se0, which determines
+ the sampling probability gap between the most and the least uncertain samples, and (ii) the window
+ size q, which determines how many recent label predictions are involved in estimating the uncertainty.
+ To decide the best hyperparameters, we trained ResNet (L=50) on CIFAR-10 and CIFAR-100 with a
+ momentum optimizer and chose the two values by a grid search over
+ se0 ∈ {1, 10, 100, 1000} and q ∈ {5, 10, 15}.
750
+
751
+
752
+
753
+ [Figure: best test error versus initial selection pressure se0 ∈ {1, 10, 100, 1000} for window sizes q = 5, 10, 15. Panels: (a) CIFAR-10, (b) CIFAR-100.]
778
+
779
+ Figure 6: Grid search on CIFAR-10 and CIFAR-100 datasets using ResNet.
780
+
781
+ Figure 6 shows the test errors of Recency Bias obtained by the grid search on the two datasets.
782
+ Regarding the initial selection pressure se0, the lowest test error was typically achieved when the
783
+ _se0 value was 100. As for the window size q, the test error was almost always the lowest when the q_
784
+ value was 10. Similar trends were observed for the other combinations of a neural network and an
785
+ optimizer. Therefore, in all experiments, we set se0 to be 100 and q to be 10.
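+
+ A hedged sketch of this grid search (illustrative; train_and_evaluate is a hypothetical helper that
+ trains the network with the given setting and returns the best test error):
+
+ ```python
+ best = None
+ for se0 in [1, 10, 100, 1000]:
+     for q in [5, 10, 15]:
+         err = train_and_evaluate(se0=se0, window_size=q)   # hypothetical helper
+         if best is None or err < best[0]:
+             best = (err, se0, q)
+ print(f"best test error {best[0]:.3f} with se0={best[1]}, q={best[2]}")
+ ```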
786
+
787
+
788
+ -----
789
+
790
+ B GENERALIZATION OF Recency Bias
+
+ CONVERGENCE CURVES USING DENSENET WITH SGD
+
+ Figure 7 shows the convergence curves of training loss and test error for four batch selection strategies
+ using DenseNet and an SGD optimizer, which corresponds to the right side of Table 1.
+
+ [Figure: convergence plots. Legend: Random Batch, Online Batch, Active Bias, Recency Bias. Panels: (a) MNIST Training Loss, (b) MNIST Test Error, (c) CIFAR-10 Training Loss, (d) CIFAR-10 Test Error, (e) CIFAR-100 Training Loss, (f) CIFAR-100 Test Error. x-axis: Time (s).]
877
+
878
+ Figure 7: Convergence curves of four batch selection strategies using DenseNet with SGD.
879
+
880
+
881
+ -----
882
+
883
+ CONVERGENCE CURVES USING RESNET WITH MOMENTUM
+
+ Figure 8 shows the convergence curves of training loss and test error for four batch selection strategies
+ using ResNet and a momentum optimizer, which corresponds to the left side of Table 2.
+
+ [Figure: convergence plots. Legend: Random Batch, Online Batch, Active Bias, Recency Bias. Panels: (a) MNIST Training Loss, (b) MNIST Test Error, (c) CIFAR-10 Training Loss, (d) CIFAR-10 Test Error, (e) CIFAR-100 Training Loss, (f) CIFAR-100 Test Error. x-axis: Time (s).]
928
+
929
+ Figure 8: Convergence curves of four batch selection strategies using ResNet with momentum.
930
+
931
+
932
+ -----
933
+
934
+ CONVERGENCE CURVES USING RESNET WITH SGD
+
+ Figure 9 shows the convergence curves of training loss and test error for four batch selection strategies
+ using ResNet and an SGD optimizer, which corresponds to the right side of Table 2.
+
+ [Figure: convergence plots. Legend: Random Batch, Online Batch, Active Bias, Recency Bias. Panels: (a) MNIST Training Loss, (b) MNIST Test Error, (c) CIFAR-10 Training Loss, (d) CIFAR-10 Test Error, (e) CIFAR-100 Training Loss, (f) CIFAR-100 Test Error. x-axis: Time (s).]
1031
+ Figure 9: Convergence curves of four batch selection strategies using ResNet with SGD.
1032
+
1033
+
1034
+ -----
1035
+
ai_scientist/fewshot_examples/attention.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "review": "{\n \"Summary\": \"The paper proposes the Transformer, a novel neural network architecture that relies entirely on self-attention mechanisms, eschewing traditional recurrent and convolutional layers. This innovation allows the model to achieve state-of-the-art results in machine translation tasks with significant improvements in both training efficiency and translation quality. The paper includes detailed descriptions of the model architecture, including multi-head attention and positional encodings, as well as extensive experimental results to validate the model's performance.\",\n \"Questions\": [\n \"Could the authors provide more detailed comparisons with other recent models not included in Table 2?\",\n \"What is the impact of varying the number of layers (N) in both the encoder and decoder stacks?\",\n \"Can the authors provide more insights into the choice of hyperparameters, especially the learning rate schedule and warmup steps?\"\n ],\n \"Limitations\": [\n \"The paper does not explore the application of the Transformer to tasks beyond machine translation, such as image or audio processing.\",\n \"The discussion on the potential negative societal impacts of the model is minimal and could be expanded.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 4,\n \"Presentation\": 3,\n \"Contribution\": 4,\n \"Overall\": 8,\n \"Confidence\": 5,\n \"Strengths\": [\n \"The Transformer model introduces a highly innovative use of self-attention mechanisms, replacing traditional recurrent and convolutional layers.\",\n \"Comprehensive experimental validation showing state-of-the-art performance in machine translation tasks.\",\n \"Clear and detailed description of the model architecture and its components, facilitating reproducibility and further research.\"\n ],\n \"Weaknesses\": [\n \"Limited discussion on the application of the model to other domains beyond machine translation.\",\n \"The paper could benefit from a deeper analysis of the potential negative societal impacts of the model.\"\n ],\n \"Originality\": 4,\n \"Quality\": 4,\n \"Clarity\": 4,\n \"Significance\": 4,\n \"Decision\": \"Accept\"\n}"
3
+ }
ai_scientist/fewshot_examples/attention.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d87d482d5ae7960e2e43d7dd6d21377e60e73e8fce1bf2a01aff7aca8a08c537
3
+ size 569417
ai_scientist/fewshot_examples/attention.txt ADDED
@@ -0,0 +1,662 @@
1
+ # Attention Is All You Need
2
+
3
+
4
+ **Ashish Vaswani[∗]**
5
+ Google Brain
6
+ ```
7
+ avaswani@google.com
8
+
9
+ ```
10
+ **Llion Jones[∗]**
11
+ Google Research
12
+ ```
13
+ llion@google.com
14
+
15
+ ```
16
+
17
+ **Noam Shazeer[∗]**
18
+ Google Brain
19
+ ```
20
+ noam@google.com
21
+
22
+ ```
23
+
24
+ **Niki Parmar[∗]**
25
+ Google Research
26
+ ```
27
+ nikip@google.com
28
+
29
+ ```
30
+
31
+ **Jakob Uszkoreit[∗]**
32
+ Google Research
33
+ ```
34
+ usz@google.com
35
+
36
+ ```
37
+
38
+ **Aidan N. Gomez[∗†]**
39
+ University of Toronto
40
+ ```
41
+ aidan@cs.toronto.edu
42
+
43
+ ```
44
+
45
+ **Łukasz Kaiser[∗]**
46
+ Google Brain
47
+ ```
48
+ lukaszkaiser@google.com
49
+
50
+ ```
51
+
52
+ **Illia Polosukhin[∗‡]**
53
+ ```
54
+ illia.polosukhin@gmail.com
55
+
56
+ ```
57
+ **Abstract**
58
+
59
+ The dominant sequence transduction models are based on complex recurrent or
60
+ convolutional neural networks that include an encoder and a decoder. The best
61
+ performing models also connect the encoder and decoder through an attention
62
+ mechanism. We propose a new simple network architecture, the Transformer,
63
+ based solely on attention mechanisms, dispensing with recurrence and convolutions
64
+ entirely. Experiments on two machine translation tasks show these models to
65
+ be superior in quality while being more parallelizable and requiring significantly
66
+ less time to train. Our model achieves 28.4 BLEU on the WMT 2014 Englishto-German translation task, improving over the existing best results, including
67
+ ensembles, by over 2 BLEU. On the WMT 2014 English-to-French translation task,
68
+ our model establishes a new single-model state-of-the-art BLEU score of 41.0 after
69
+ training for 3.5 days on eight GPUs, a small fraction of the training costs of the
70
+ best models from the literature.
71
+
72
+ **1** **Introduction**
73
+
74
+ Recurrent neural networks, long short-term memory [12] and gated recurrent [7] neural networks
75
+ in particular, have been firmly established as state of the art approaches in sequence modeling and
76
+ transduction problems such as language modeling and machine translation [29, 2, 5]. Numerous
77
+ efforts have since continued to push the boundaries of recurrent language models and encoder-decoder
78
+ architectures [31, 21, 13].
79
+
80
+ _∗Equal contribution. Listing order is random. Jakob proposed replacing RNNs with self-attention and started_
81
+ the effort to evaluate this idea. Ashish, with Illia, designed and implemented the first Transformer models and
82
+ has been crucially involved in every aspect of this work. Noam proposed scaled dot-product attention, multi-head
83
+ attention and the parameter-free position representation and became the other person involved in nearly every
84
+ detail. Niki designed, implemented, tuned and evaluated countless model variants in our original codebase and
85
+ tensor2tensor. Llion also experimented with novel model variants, was responsible for our initial codebase, and
86
+ efficient inference and visualizations. Lukasz and Aidan spent countless long days designing various parts of and
87
+ implementing tensor2tensor, replacing our earlier codebase, greatly improving results and massively accelerating
88
+ our research.
89
+ _†Work performed while at Google Brain._
90
+ _‡Work performed while at Google Research._
91
+
92
+ 31st Conference on Neural Information Processing Systems (NIPS 2017), Long Beach, CA, USA.
93
+
94
+
95
+ -----
96
+
97
+ Recurrent models typically factor computation along the symbol positions of the input and output
98
+ sequences. Aligning the positions to steps in computation time, they generate a sequence of hidden
99
+ states ht, as a function of the previous hidden state ht 1 and the input for position t. This inherently
100
+ _−_
101
+ sequential nature precludes parallelization within training examples, which becomes critical at longer
102
+ sequence lengths, as memory constraints limit batching across examples. Recent work has achieved
103
+ significant improvements in computational efficiency through factorization tricks [18] and conditional
104
+ computation [26], while also improving model performance in case of the latter. The fundamental
105
+ constraint of sequential computation, however, remains.
106
+
107
+ Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing modeling of dependencies without regard to their distance in
108
+ the input or output sequences [2, 16]. In all but a few cases [22], however, such attention mechanisms
109
+ are used in conjunction with a recurrent network.
110
+
111
+ In this work we propose the Transformer, a model architecture eschewing recurrence and instead
112
+ relying entirely on an attention mechanism to draw global dependencies between input and output.
113
+ The Transformer allows for significantly more parallelization and can reach a new state of the art in
114
+ translation quality after being trained for as little as twelve hours on eight P100 GPUs.
115
+
116
+ **2** **Background**
117
+
118
+ The goal of reducing sequential computation also forms the foundation of the Extended Neural GPU
119
+
120
+ [20], ByteNet [15] and ConvS2S [8], all of which use convolutional neural networks as basic building
121
+ block, computing hidden representations in parallel for all input and output positions. In these models,
122
+ the number of operations required to relate signals from two arbitrary input or output positions grows
123
+ in the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes
124
+ it more difficult to learn dependencies between distant positions [11]. In the Transformer this is
125
+ reduced to a constant number of operations, albeit at the cost of reduced effective resolution due
126
+ to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as
127
+ described in section 3.2.
128
+
129
+ Self-attention, sometimes called intra-attention is an attention mechanism relating different positions
130
+ of a single sequence in order to compute a representation of the sequence. Self-attention has been
131
+ used successfully in a variety of tasks including reading comprehension, abstractive summarization,
132
+ textual entailment and learning task-independent sentence representations [4, 22, 23, 19].
133
+
134
+ End-to-end memory networks are based on a recurrent attention mechanism instead of sequencealigned recurrence and have been shown to perform well on simple-language question answering and
135
+ language modeling tasks [28].
136
+
137
+ To the best of our knowledge, however, the Transformer is the first transduction model relying
138
+ entirely on self-attention to compute representations of its input and output without using sequencealigned RNNs or convolution. In the following sections, we will describe the Transformer, motivate
139
+ self-attention and discuss its advantages over models such as [14, 15] and [8].
140
+
141
+ **3** **Model Architecture**
142
+
143
+ Most competitive neural sequence transduction models have an encoder-decoder structure [5, 2, 29].
144
+ Here, the encoder maps an input sequence of symbol representations (x1, ..., xn) to a sequence
145
+ of continuous representations z = (z1, ..., zn). Given z, the decoder then generates an output
146
+ sequence (y1, ..., ym) of symbols one element at a time. At each step the model is auto-regressive
147
+
148
+ [9], consuming the previously generated symbols as additional input when generating the next.
149
+
150
+ The Transformer follows this overall architecture using stacked self-attention and point-wise, fully
151
+ connected layers for both the encoder and decoder, shown in the left and right halves of Figure 1,
152
+ respectively.
153
+
154
+ **3.1** **Encoder and Decoder Stacks**
155
+
156
+ **Encoder:** The encoder is composed of a stack of N = 6 identical layers. Each layer has two
157
+ sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position
158
+
159
+ -----
160
+
161
+ Figure 1: The Transformer - model architecture.
162
+
163
+ wise fully connected feed-forward network. We employ a residual connection [10] around each of
164
+ the two sub-layers, followed by layer normalization [1]. That is, the output of each sub-layer is
165
+ LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer
166
+ itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding
167
+ layers, produce outputs of dimension dmodel = 512.
168
+
169
+ **Decoder:** The decoder is also composed of a stack of N = 6 identical layers. In addition to the two
170
+ sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head
171
+ attention over the output of the encoder stack. Similar to the encoder, we employ residual connections
172
+ around each of the sub-layers, followed by layer normalization. We also modify the self-attention
173
+ sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This
174
+ masking, combined with fact that the output embeddings are offset by one position, ensures that the
175
+ predictions for position i can depend only on the known outputs at positions less than i.
176
+
177
+ **3.2** **Attention**
178
+
179
+ An attention function can be described as mapping a query and a set of key-value pairs to an output,
180
+ where the query, keys, values, and output are all vectors. The output is computed as a weighted sum
181
+ of the values, where the weight assigned to each value is computed by a compatibility function of the
182
+ query with the corresponding key.
183
+
184
+ **3.2.1** **Scaled Dot-Product Attention**
185
+
186
+ We call our particular attention "Scaled Dot-Product Attention" (Figure 2). The input consists of
187
+ queries and keys of dimension dk, and values of dimension dv. We compute the dot products of the
188
+
189
+
190
+ -----
191
+
192
+ Scaled Dot-Product Attention Multi-Head Attention
193
+
194
+ Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several
195
+ attention layers running in parallel.
196
+
197
+ query with all keys, divide each by √d_k, and apply a softmax function to obtain the weights on the
+ values.
201
+
202
+ In practice, we compute the attention function on a set of queries simultaneously, packed together
203
+ into a matrix Q. The keys and values are also packed together into matrices K and V . We compute
204
+ the matrix of outputs as:
205
+
206
+ Attention(Q, K, V) = softmax(QK^T / √d_k) V    (1)
207
+
208
+ The two most commonly used attention functions are additive attention [2], and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor
209
+ of 1/√d_k. Additive attention computes the compatibility function using a feed-forward network with
210
+ a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is
211
+ much faster and more space-efficient in practice, since it can be implemented using highly optimized
212
+ matrix multiplication code.
213
+
214
+ While for small values of dk the two mechanisms perform similarly, additive attention outperforms
215
+ dot product attention without scaling for larger values of dk [3]. We suspect that for large values of
216
+ _dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has_
217
+ extremely small gradients [4]. To counteract this effect, we scale the dot products by 1/√d_k.
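+
+ A minimal NumPy sketch of equation (1) (illustrative, not the paper's reference implementation; the
+ batch dimension is omitted):
+
+ ```python
+ import numpy as np
+
+ def scaled_dot_product_attention(Q, K, V):
+     # Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)
+     d_k = Q.shape[-1]
+     scores = Q @ K.T / np.sqrt(d_k)                      # (n_q, n_k)
+     weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
+     weights /= weights.sum(axis=-1, keepdims=True)       # row-wise softmax
+     return weights @ V                                   # (n_q, d_v)
+ ```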
218
+
219
+ **3.2.2** **Multi-Head Attention**
220
+
221
+ Instead of performing a single attention function with dmodel-dimensional keys, values and queries,
222
+ we found it beneficial to linearly project the queries, keys and values h times with different, learned
223
+ linear projections to dk, dk and dv dimensions, respectively. On each of these projected versions of
224
+ queries, keys and values we then perform the attention function in parallel, yielding dv-dimensional
225
+ output values. These are concatenated and once again projected, resulting in the final values, as
226
+ depicted in Figure 2.
227
+
228
+ Multi-head attention allows the model to jointly attend to information from different representation
229
+ subspaces at different positions. With a single attention head, averaging inhibits this.
230
+
231
+ (Footnote 4) To illustrate why the dot products get large, assume that the components of q and k are independent random
+ variables with mean 0 and variance 1. Then their dot product, q · k = Σ_{i=1}^{d_k} q_i k_i, has mean 0 and variance d_k.
235
+
236
+
237
+ -----
238
+
239
+ MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
+
+ where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
+
+ where the projections are parameter matrices W_i^Q ∈ R^(d_model × d_k), W_i^K ∈ R^(d_model × d_k), W_i^V ∈ R^(d_model × d_v)
+ and W^O ∈ R^(h·d_v × d_model).
246
+
247
+ In this work we employ h = 8 parallel attention layers, or heads. For each of these we use
248
+ _dk = dv = dmodel/h = 64. Due to the reduced dimension of each head, the total computational cost_
249
+ is similar to that of single-head attention with full dimensionality.
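+
+ A hedged NumPy sketch of the projections described above (random matrices stand in for the learned
+ projections; `scaled_dot_product_attention` refers to the sketch in Section 3.2.1):
+
+ ```python
+ import numpy as np
+
+ rng = np.random.default_rng(0)
+ d_model, h = 512, 8
+ d_k = d_v = d_model // h                                 # 64, as in the paper
+
+ W_Q = rng.normal(size=(h, d_model, d_k))
+ W_K = rng.normal(size=(h, d_model, d_k))
+ W_V = rng.normal(size=(h, d_model, d_v))
+ W_O = rng.normal(size=(h * d_v, d_model))
+
+ def multi_head_attention(Q, K, V):
+     # Q: (n_q, d_model), K, V: (n_k, d_model)
+     heads = [scaled_dot_product_attention(Q @ W_Q[i], K @ W_K[i], V @ W_V[i])
+              for i in range(h)]                          # each (n_q, d_v)
+     return np.concatenate(heads, axis=-1) @ W_O          # (n_q, d_model)
+ ```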
250
+
251
+ **3.2.3** **Applications of Attention in our Model**
252
+
253
+ The Transformer uses multi-head attention in three different ways:
254
+
255
+ _• In "encoder-decoder attention" layers, the queries come from the previous decoder layer,_
256
+ and the memory keys and values come from the output of the encoder. This allows every
257
+ position in the decoder to attend over all positions in the input sequence. This mimics the
258
+ typical encoder-decoder attention mechanisms in sequence-to-sequence models such as
259
+
260
+ [31, 2, 8].
261
+
262
+ _• The encoder contains self-attention layers. In a self-attention layer all of the keys, values_
263
+ and queries come from the same place, in this case, the output of the previous layer in the
264
+ encoder. Each position in the encoder can attend to all positions in the previous layer of the
265
+ encoder.
266
+
267
+ _• Similarly, self-attention layers in the decoder allow each position in the decoder to attend to_
268
+ all positions in the decoder up to and including that position. We need to prevent leftward
269
+ information flow in the decoder to preserve the auto-regressive property. We implement this
270
+ inside of scaled dot-product attention by masking out (setting to −∞) all values in the input
271
+ of the softmax which correspond to illegal connections. See Figure 2.
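+
+ A small illustrative sketch of the masking described in the last item above (not from the paper):
+
+ ```python
+ import numpy as np
+
+ def causal_mask(n):
+     # True where position j > i, i.e. the "illegal" connections for query position i
+     return np.triu(np.ones((n, n), dtype=bool), k=1)
+
+ def masked_softmax(scores):
+     scores = scores.copy()
+     scores[causal_mask(scores.shape[-1])] = -np.inf      # mask out illegal connections
+     w = np.exp(scores - scores.max(axis=-1, keepdims=True))
+     return w / w.sum(axis=-1, keepdims=True)
+ ```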
272
+
273
+ **3.3** **Position-wise Feed-Forward Networks**
274
+
275
+ In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully
276
+ connected feed-forward network, which is applied to each position separately and identically. This
277
+ consists of two linear transformations with a ReLU activation in between.
278
+
279
+ FFN(x) = max(0, xW1 + b1)W2 + b2 (2)
280
+
281
+ While the linear transformations are the same across different positions, they use different parameters
282
+ from layer to layer. Another way of describing this is as two convolutions with kernel size 1.
283
+ The dimensionality of input and output is dmodel = 512, and the inner-layer has dimensionality
284
+ _dff = 2048._
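+
+ A minimal sketch of equation (2) with the stated dimensions (illustrative; random matrices stand in
+ for the learned parameters):
+
+ ```python
+ import numpy as np
+
+ rng = np.random.default_rng(0)
+ d_model, d_ff = 512, 2048
+ W1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
+ W2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)
+
+ def ffn(x):
+     # applied to each position separately and identically; x: (n, d_model)
+     return np.maximum(0.0, x @ W1 + b1) @ W2 + b2
+ ```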
285
+
286
+ **3.4** **Embeddings and Softmax**
287
+
288
+ Similarly to other sequence transduction models, we use learned embeddings to convert the input
289
+ tokens and output tokens to vectors of dimension dmodel. We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In
290
+ our model, we share the same weight matrix between the two embedding layers and the pre-softmax
291
+ linear transformation, similar to [24]. In the embedding layers, we multiply those weights by √d_model.
294
+
295
+ **3.5** **Positional Encoding**
296
+
297
+ Since our model contains no recurrence and no convolution, in order for the model to make use of the
298
+ order of the sequence, we must inject some information about the relative or absolute position of the
299
+ tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the
300
+
301
+
302
+ -----
303
+
304
+ Table 1: Maximum path lengths, per-layer complexity and minimum number of sequential operations
305
+ for different layer types. n is the sequence length, d is the representation dimension, k is the kernel
306
+ size of convolutions and r the size of the neighborhood in restricted self-attention.
307
+
308
+ |Layer Type|Complexity per Layer|Sequential Operations|Maximum Path Length|
+ |---|---|---|---|
+ |Self-Attention|O(n² · d)|O(1)|O(1)|
+ |Recurrent|O(n · d²)|O(n)|O(n)|
+ |Convolutional|O(k · n · d²)|O(1)|O(log_k(n))|
+ |Self-Attention (restricted)|O(r · n · d)|O(1)|O(n/r)|
316
+
317
+ bottoms of the encoder and decoder stacks. The positional encodings have the same dimension dmodel
318
+ as the embeddings, so that the two can be summed. There are many choices of positional encodings,
319
+ learned and fixed [8].
320
+
321
+ In this work, we use sine and cosine functions of different frequencies:
322
+
323
+ PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
+
+ PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
326
+
327
+ where pos is the position and i is the dimension. That is, each dimension of the positional encoding
328
+ corresponds to a sinusoid. The wavelengths form a geometric progression from 2π to 10000 · 2π. We
329
+ chose this function because we hypothesized it would allow the model to easily learn to attend by
330
+ relative positions, since for any fixed offset k, PE_(pos+k) can be represented as a linear function of
+ PE_pos.
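+
+ A short sketch of these encodings (illustrative):
+
+ ```python
+ import numpy as np
+
+ def positional_encoding(max_len, d_model):
+     pos = np.arange(max_len)[:, None]                     # (max_len, 1)
+     i = np.arange(d_model // 2)[None, :]                  # (1, d_model/2)
+     angles = pos / np.power(10000.0, 2.0 * i / d_model)   # wavelengths 2π ... 10000·2π
+     pe = np.zeros((max_len, d_model))
+     pe[:, 0::2] = np.sin(angles)                          # even dimensions
+     pe[:, 1::2] = np.cos(angles)                          # odd dimensions
+     return pe
+ ```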
332
+
333
+ We also experimented with using learned positional embeddings [8] instead, and found that the two
334
+ versions produced nearly identical results (see Table 3 row (E)). We chose the sinusoidal version
335
+ because it may allow the model to extrapolate to sequence lengths longer than the ones encountered
336
+ during training.
337
+
338
+ **4** **Why Self-Attention**
339
+
340
+ In this section we compare various aspects of self-attention layers to the recurrent and convolutional layers commonly used for mapping one variable-length sequence of symbol representations
341
+ (x1, ..., xn) to another sequence of equal length (z1, ..., zn), with xi, zi ∈ R^d, such as a hidden
+ layer in a typical sequence transduction encoder or decoder. Motivating our use of self-attention we
+ consider three desiderata.
343
+
344
+ One is the total computational complexity per layer. Another is the amount of computation that can
345
+ be parallelized, as measured by the minimum number of sequential operations required.
346
+
347
+ The third is the path length between long-range dependencies in the network. Learning long-range
348
+ dependencies is a key challenge in many sequence transduction tasks. One key factor affecting the
349
+ ability to learn such dependencies is the length of the paths forward and backward signals have to
350
+ traverse in the network. The shorter these paths between any combination of positions in the input
351
+ and output sequences, the easier it is to learn long-range dependencies [11]. Hence we also compare
352
+ the maximum path length between any two input and output positions in networks composed of the
353
+ different layer types.
354
+
355
+ As noted in Table 1, a self-attention layer connects all positions with a constant number of sequentially
356
+ executed operations, whereas a recurrent layer requires O(n) sequential operations. In terms of
357
+ computational complexity, self-attention layers are faster than recurrent layers when the sequence
358
+ length n is smaller than the representation dimensionality d, which is most often the case with
359
+ sentence representations used by state-of-the-art models in machine translations, such as word-piece
360
+
361
+ [31] and byte-pair [25] representations. To improve computational performance for tasks involving
362
+ very long sequences, self-attention could be restricted to considering only a neighborhood of size r in
363
+
364
+
365
+ -----
366
+
367
+ the input sequence centered around the respective output position. This would increase the maximum
368
+ path length to O(n/r). We plan to investigate this approach further in future work.
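+
+ As a quick numeric illustration of the comparison above (values chosen only for illustration): with a
+ sentence length of n = 70 and d = 512, per-layer work scales as n²·d = 70² × 512 ≈ 2.5 × 10^6 for
+ self-attention versus n·d² = 70 × 512² ≈ 1.8 × 10^7 for a recurrent layer, consistent with
+ self-attention being cheaper whenever n < d.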
369
+
370
+ A single convolutional layer with kernel width k < n does not connect all pairs of input and output
371
+ positions. Doing so requires a stack of O(n/k) convolutional layers in the case of contiguous kernels,
372
+ or O(log_k(n)) in the case of dilated convolutions [15], increasing the length of the longest paths
373
+ between any two positions in the network. Convolutional layers are generally more expensive than
374
+ recurrent layers, by a factor of k. Separable convolutions [6], however, decrease the complexity
375
+ considerably, to O(k · n · d + n · d²). Even with k = n, however, the complexity of a separable
376
+ convolution is equal to the combination of a self-attention layer and a point-wise feed-forward layer,
377
+ the approach we take in our model.
378
+
379
+ As side benefit, self-attention could yield more interpretable models. We inspect attention distributions
380
+ from our models and present and discuss examples in the appendix. Not only do individual attention
381
+ heads clearly learn to perform different tasks, many appear to exhibit behavior related to the syntactic
382
+ and semantic structure of the sentences.
383
+
384
+ **5** **Training**
385
+
386
+ This section describes the training regime for our models.
387
+
388
+ **5.1** **Training Data and Batching**
389
+
390
+ We trained on the standard WMT 2014 English-German dataset consisting of about 4.5 million
391
+ sentence pairs. Sentences were encoded using byte-pair encoding [3], which has a shared sourcetarget vocabulary of about 37000 tokens. For English-French, we used the significantly larger WMT
392
+ 2014 English-French dataset consisting of 36M sentences and split tokens into a 32000 word-piece
393
+ vocabulary [31]. Sentence pairs were batched together by approximate sequence length. Each training
394
+ batch contained a set of sentence pairs containing approximately 25000 source tokens and 25000
395
+ target tokens.
396
+
397
+ **5.2** **Hardware and Schedule**
398
+
399
+ We trained our models on one machine with 8 NVIDIA P100 GPUs. For our base models using
400
+ the hyperparameters described throughout the paper, each training step took about 0.4 seconds. We
401
+ trained the base models for a total of 100,000 steps or 12 hours. For our big models (described on the
402
+ bottom line of Table 3), step time was 1.0 seconds. The big models were trained for 300,000 steps
403
+ (3.5 days).
404
+
405
+ **5.3** **Optimizer**
406
+
407
+ We used the Adam optimizer [17] with β1 = 0.9, β2 = 0.98 and ε = 10^−9. We varied the learning
408
+ rate over the course of training, according to the formula:
409
+
410
+ lrate = d_model^(−0.5) · min(step_num^(−0.5), step_num · warmup_steps^(−1.5))    (3)
413
+
414
+ This corresponds to increasing the learning rate linearly for the first warmup_steps training steps,
415
+ and decreasing it thereafter proportionally to the inverse square root of the step number. We used
416
+ _warmup_steps = 4000._
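+
+ A small sketch of this schedule (illustrative):
+
+ ```python
+ def transformer_lrate(step, d_model=512, warmup_steps=4000):
+     # equation (3): linear warmup followed by inverse-square-root decay
+     step = max(step, 1)
+     return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
+ ```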
417
+
418
+ **5.4** **Regularization**
419
+
420
+ We employ three types of regularization during training:
421
+
422
+ **Residual Dropout** We apply dropout [27] to the output of each sub-layer, before it is added to the
423
+ sub-layer input and normalized. In addition, we apply dropout to the sums of the embeddings and the
424
+ positional encodings in both the encoder and decoder stacks. For the base model, we use a rate of
425
+ _Pdrop = 0.1._
426
+
427
+
428
+ -----
429
+
430
+ Table 2: The Transformer achieves better BLEU scores than previous state-of-the-art models on the
431
+ English-to-German and English-to-French newstest2014 tests at a fraction of the training cost.
432
+
433
+ |Model|BLEU EN-DE|BLEU EN-FR|Training Cost (FLOPs) EN-DE|Training Cost (FLOPs) EN-FR|
+ |---|---|---|---|---|
+ |ByteNet [15]|23.75||||
+ |Deep-Att + PosUnk [32]||39.2||1.0 · 10^20|
+ |GNMT + RL [31]|24.6|39.92|2.3 · 10^19|1.4 · 10^20|
+ |ConvS2S [8]|25.16|40.46|9.6 · 10^18|1.5 · 10^20|
+ |MoE [26]|26.03|40.56|2.0 · 10^19|1.2 · 10^20|
+ |Deep-Att + PosUnk Ensemble [32]||40.4||8.0 · 10^20|
+ |GNMT + RL Ensemble [31]|26.30|41.16|1.8 · 10^20|1.1 · 10^21|
+ |ConvS2S Ensemble [8]|26.36|**41.29**|7.7 · 10^19|1.2 · 10^21|
+ |Transformer (base model)|27.3|38.1|**3.3 · 10^18**|**3.3 · 10^18**|
+ |Transformer (big)|**28.4**|**41.0**|2.3 · 10^19|2.3 · 10^19|
456
+
457
+
458
+ **Label Smoothing** During training, we employed label smoothing of value ε_ls = 0.1 [30]. This
459
+ hurts perplexity, as the model learns to be more unsure, but improves accuracy and BLEU score.
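+
+ A sketch of one common formulation of label smoothing (illustrative; the paper's exact variant may
+ distribute the smoothing mass slightly differently):
+
+ ```python
+ import numpy as np
+
+ def smoothed_targets(labels, vocab_size, eps=0.1):
+     # keep 1 - eps on the true token; spread eps uniformly over the other tokens
+     t = np.full((len(labels), vocab_size), eps / (vocab_size - 1))
+     t[np.arange(len(labels)), labels] = 1.0 - eps
+     return t
+ ```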
460
+
461
+ **6** **Results**
462
+
463
+ **6.1** **Machine Translation**
464
+
465
+ On the WMT 2014 English-to-German translation task, the big transformer model (Transformer (big)
466
+ in Table 2) outperforms the best previously reported models (including ensembles) by more than 2.0
467
+ BLEU, establishing a new state-of-the-art BLEU score of 28.4. The configuration of this model is
468
+ listed in the bottom line of Table 3. Training took 3.5 days on 8 P100 GPUs. Even our base model
469
+ surpasses all previously published models and ensembles, at a fraction of the training cost of any of
470
+ the competitive models.
471
+
472
+ On the WMT 2014 English-to-French translation task, our big model achieves a BLEU score of 41.0,
473
+ outperforming all of the previously published single models, at less than 1/4 the training cost of the
474
+ previous state-of-the-art model. The Transformer (big) model trained for English-to-French used
475
+ dropout rate Pdrop = 0.1, instead of 0.3.
476
+
477
+ For the base models, we used a single model obtained by averaging the last 5 checkpoints, which
478
+ were written at 10-minute intervals. For the big models, we averaged the last 20 checkpoints. We
479
+ used beam search with a beam size of 4 and length penalty α = 0.6 [31]. These hyperparameters
480
+ were chosen after experimentation on the development set. We set the maximum output length during
481
+ inference to input length + 50, but terminate early when possible [31].
482
+
483
+ Table 2 summarizes our results and compares our translation quality and training costs to other model
484
+ architectures from the literature. We estimate the number of floating point operations used to train a
485
+ model by multiplying the training time, the number of GPUs used, and an estimate of the sustained
486
+ single-precision floating-point capacity of each GPU [5].
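+
+ For example, under this estimate the base model's cost works out to roughly
+ 12 h × 3600 s/h × 8 GPUs × 9.5 × 10^12 FLOPS ≈ 3.3 × 10^18 FLOPs, and the big model's to
+ 3.5 days × 86,400 s/day × 8 × 9.5 × 10^12 ≈ 2.3 × 10^19 FLOPs, matching the Transformer rows in
+ Table 2 (using the P100 figure given in footnote 5).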
487
+
488
+ **6.2** **Model Variations**
489
+
490
+ To evaluate the importance of different components of the Transformer, we varied our base model
491
+ in different ways, measuring the change in performance on English-to-German translation on the
492
+ development set, newstest2013. We used beam search as described in the previous section, but no
493
+ checkpoint averaging. We present these results in Table 3.
494
+
495
+ In Table 3 rows (A), we vary the number of attention heads and the attention key and value dimensions,
496
+ keeping the amount of computation constant, as described in Section 3.2.2. While single-head
497
+ attention is 0.9 BLEU worse than the best setting, quality also drops off with too many heads.
498
+
499
+ (Footnote 5) We used values of 2.8, 3.7, 6.0 and 9.5 TFLOPS for K80, K40, M40 and P100, respectively.
500
+
501
+
502
+ -----
503
+
504
+ Table 3: Variations on the Transformer architecture. Unlisted values are identical to those of the base
505
+ model. All metrics are on the English-to-German translation development set, newstest2013. Listed
506
+ perplexities are per-wordpiece, according to our byte-pair encoding, and should not be compared to
507
+ per-word perplexities.
508
+
509
+ |Variant|N|d_model|d_ff|h|d_k|d_v|P_drop|ε_ls|train steps|PPL (dev)|BLEU (dev)|params ×10^6|
+ |---|---|---|---|---|---|---|---|---|---|---|---|---|
+ |base|6|512|2048|8|64|64|0.1|0.1|100K|4.92|25.8|65|
+ |(A)||||1|512|512||||5.29|24.9||
+ |(A)||||4|128|128||||5.00|25.5||
+ |(A)||||16|32|32||||4.91|25.8||
+ |(A)||||32|16|16||||5.01|25.4||
+ |(B)|||||16|||||5.16|25.1|58|
+ |(B)|||||32|||||5.01|25.4|60|
+ |(C)|2|||||||||6.11|23.7|36|
+ |(C)|4|||||||||5.19|25.3|50|
+ |(C)|8|||||||||4.88|25.5|80|
+ |(C)||256|||32|32||||5.75|24.5|28|
+ |(C)||1024|||128|128||||4.66|26.0|168|
+ |(C)|||1024|||||||5.12|25.4|53|
+ |(C)|||4096|||||||4.75|26.2|90|
+ |(D)|||||||0.0|||5.77|24.6||
+ |(D)|||||||0.2|||4.95|25.5||
+ |(D)||||||||0.0||4.67|25.3||
+ |(D)||||||||0.2||5.47|25.7||
+ |(E)|positional embedding instead of sinusoids|||||||||4.92|25.7||
+ |big|6|1024|4096|16|||0.3||300K|4.33|26.4|213|
518
+
519
+
520
+
521
+ In Table 3 rows (B), we observe that reducing the attention key size dk hurts model quality. This
522
+ suggests that determining compatibility is not easy and that a more sophisticated compatibility
523
+ function than dot product may be beneficial. We further observe in rows (C) and (D) that, as expected,
524
+ bigger models are better, and dropout is very helpful in avoiding over-fitting. In row (E) we replace our
525
+ sinusoidal positional encoding with learned positional embeddings [8], and observe nearly identical
526
+ results to the base model.
527
+
528
+ **7** **Conclusion**
529
+
530
+ In this work, we presented the Transformer, the first sequence transduction model based entirely on
531
+ attention, replacing the recurrent layers most commonly used in encoder-decoder architectures with
532
+ multi-headed self-attention.
533
+
534
+ For translation tasks, the Transformer can be trained significantly faster than architectures based
535
+ on recurrent or convolutional layers. On both WMT 2014 English-to-German and WMT 2014
536
+ English-to-French translation tasks, we achieve a new state of the art. In the former task our best
537
+ model outperforms even all previously reported ensembles.
538
+
539
+ We are excited about the future of attention-based models and plan to apply them to other tasks. We
540
+ plan to extend the Transformer to problems involving input and output modalities other than text and
541
+ to investigate local, restricted attention mechanisms to efficiently handle large inputs and outputs
542
+ such as images, audio and video. Making generation less sequential is another research goals of ours.
543
+
544
+ [The code we used to train and evaluate our models is available at https://github.com/](https://github.com/tensorflow/tensor2tensor)
545
+ ```
546
+ tensorflow/tensor2tensor.
547
+
548
+ ```
549
+ **Acknowledgements** We are grateful to Nal Kalchbrenner and Stephan Gouws for their fruitful
550
+ comments, corrections and inspiration.
551
+
552
+
553
+ -----
554
+
555
+ **References**
556
+
557
+ [1] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint
558
+ _arXiv:1607.06450, 2016._
559
+
560
+ [2] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly
561
+ learning to align and translate. CoRR, abs/1409.0473, 2014.
562
+
563
+ [3] Denny Britz, Anna Goldie, Minh-Thang Luong, and Quoc V. Le. Massive exploration of neural
564
+ machine translation architectures. CoRR, abs/1703.03906, 2017.
565
+
566
+ [4] Jianpeng Cheng, Li Dong, and Mirella Lapata. Long short-term memory-networks for machine
567
+ reading. arXiv preprint arXiv:1601.06733, 2016.
568
+
569
+ [5] Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Fethi Bougares, Holger Schwenk,
570
+ and Yoshua Bengio. Learning phrase representations using rnn encoder-decoder for statistical
571
+ machine translation. CoRR, abs/1406.1078, 2014.
572
+
573
+ [6] Francois Chollet. Xception: Deep learning with depthwise separable convolutions. arXiv
574
+ _preprint arXiv:1610.02357, 2016._
575
+
576
+ [7] Junyoung Chung, Çaglar Gülçehre, Kyunghyun Cho, and Yoshua Bengio. Empirical evaluation
577
+ of gated recurrent neural networks on sequence modeling. CoRR, abs/1412.3555, 2014.
578
+
579
+ [8] Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. Convolutional sequence to sequence learning. arXiv preprint arXiv:1705.03122v2, 2017.
580
+
581
+ [9] Alex Graves. Generating sequences with recurrent neural networks. _arXiv preprint_
582
+ _arXiv:1308.0850, 2013._
583
+
584
+ [10] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern
585
+ _Recognition, pages 770–778, 2016._
586
+
587
+ [11] Sepp Hochreiter, Yoshua Bengio, Paolo Frasconi, and Jürgen Schmidhuber. Gradient flow in
588
+ recurrent nets: the difficulty of learning long-term dependencies, 2001.
589
+
590
+ [12] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation,
591
+ 9(8):1735–1780, 1997.
592
+
593
+ [13] Rafal Jozefowicz, Oriol Vinyals, Mike Schuster, Noam Shazeer, and Yonghui Wu. Exploring
594
+ the limits of language modeling. arXiv preprint arXiv:1602.02410, 2016.
595
+
596
+ [14] Łukasz Kaiser and Ilya Sutskever. Neural GPUs learn algorithms. In International Conference
597
+ _on Learning Representations (ICLR), 2016._
598
+
599
+ [15] Nal Kalchbrenner, Lasse Espeholt, Karen Simonyan, Aaron van den Oord, Alex Graves, and Koray Kavukcuoglu. Neural machine translation in linear time. arXiv preprint arXiv:1610.10099v2,
600
+ 2017.
601
+
602
+ [16] Yoon Kim, Carl Denton, Luong Hoang, and Alexander M. Rush. Structured attention networks.
603
+ In International Conference on Learning Representations, 2017.
604
+
605
+ [17] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.
606
+
607
+ [18] Oleksii Kuchaiev and Boris Ginsburg. Factorization tricks for LSTM networks. arXiv preprint
608
+ _arXiv:1703.10722, 2017._
609
+
610
+ [19] Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen
611
+ Zhou, and Yoshua Bengio. A structured self-attentive sentence embedding. arXiv preprint
612
+ _arXiv:1703.03130, 2017._
613
+
614
+ [20] Samy Bengio Łukasz Kaiser. Can active memory replace attention? In Advances in Neural
615
+ _Information Processing Systems, (NIPS), 2016._
616
+
617
+
618
+ -----
619
+
620
+ [21] Minh-Thang Luong, Hieu Pham, and Christopher D Manning. Effective approaches to attentionbased neural machine translation. arXiv preprint arXiv:1508.04025, 2015.
621
+
622
+ [22] Ankur Parikh, Oscar Täckström, Dipanjan Das, and Jakob Uszkoreit. A decomposable attention
623
+ model. In Empirical Methods in Natural Language Processing, 2016.
624
+
625
+ [23] Romain Paulus, Caiming Xiong, and Richard Socher. A deep reinforced model for abstractive
626
+ summarization. arXiv preprint arXiv:1705.04304, 2017.
627
+
628
+ [24] Ofir Press and Lior Wolf. Using the output embedding to improve language models. arXiv
629
+ _preprint arXiv:1608.05859, 2016._
630
+
631
+ [25] Rico Sennrich, Barry Haddow, and Alexandra Birch. Neural machine translation of rare words
632
+ with subword units. arXiv preprint arXiv:1508.07909, 2015.
633
+
634
+ [26] Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton,
635
+ and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts
636
+ layer. arXiv preprint arXiv:1701.06538, 2017.
637
+
638
+ [27] Nitish Srivastava, Geoffrey E Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. Journal of Machine
639
+ _Learning Research, 15(1):1929–1958, 2014._
640
+
641
+ [28] Sainbayar Sukhbaatar, arthur szlam, Jason Weston, and Rob Fergus. End-to-end memory
642
+ networks. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors,
643
+ _Advances in Neural Information Processing Systems 28, pages 2440–2448. Curran Associates,_
644
+ Inc., 2015.
645
+
646
+ [29] Ilya Sutskever, Oriol Vinyals, and Quoc VV Le. Sequence to sequence learning with neural
647
+ networks. In Advances in Neural Information Processing Systems, pages 3104–3112, 2014.
648
+
649
+ [30] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna.
650
+ Rethinking the inception architecture for computer vision. CoRR, abs/1512.00567, 2015.
651
+
652
+ [31] Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang
653
+ Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, et al. Google’s neural machine
654
+ translation system: Bridging the gap between human and machine translation. arXiv preprint
655
+ _arXiv:1609.08144, 2016._
656
+
657
+ [32] Jie Zhou, Ying Cao, Xuguang Wang, Peng Li, and Wei Xu. Deep recurrent models with
658
+ fast-forward connections for neural machine translation. CoRR, abs/1606.04199, 2016.
659
+
660
+
661
+ -----
662
+
ai_scientist/generate_ideas.py ADDED
@@ -0,0 +1,546 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ from typing import List, Dict, Union
6
+
7
+ import backoff
8
+ import requests
9
+
10
+ from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
11
+
12
+ S2_API_KEY = os.getenv("S2_API_KEY")
13
+
14
+ idea_first_prompt = """{task_description}
15
+ <experiment.py>
16
+ {code}
17
+ </experiment.py>
18
+
19
+ Here are the ideas that you have already generated:
20
+
21
+ '''
22
+ {prev_ideas_string}
23
+ '''
24
+
25
+ Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
26
+ Note that you will not have access to any additional resources or datasets.
27
+ Make sure any idea is not overfit the specific training dataset or model, and has wider significance.
28
+
29
+ Respond in the following format:
30
+
31
+ THOUGHT:
32
+ <THOUGHT>
33
+
34
+ NEW IDEA JSON:
35
+ ```json
36
+ <JSON>
37
+ ```
38
+
39
+ In <THOUGHT>, first briefly discuss your intuitions and motivations for the idea. Detail your high-level plan, necessary design choices and ideal outcomes of the experiments. Justify how the idea is different from the existing ones.
40
+
41
+ In <JSON>, provide the new idea in JSON format with the following fields:
42
+ - "Name": A shortened descriptor of the idea. Lowercase, no spaces, underscores allowed.
43
+ - "Title": A title for the idea, will be used for the report writing.
44
+ - "Experiment": An outline of the implementation. E.g. which functions need to be added or modified, how results will be obtained, ...
45
+ - "Interestingness": A rating from 1 to 10 (lowest to highest).
46
+ - "Feasibility": A rating from 1 to 10 (lowest to highest).
47
+ - "Novelty": A rating from 1 to 10 (lowest to highest).
48
+
49
+ Be cautious and realistic on your ratings.
50
+ This JSON will be automatically parsed, so ensure the format is precise.
51
+ You will have {num_reflections} rounds to iterate on the idea, but do not need to use them all.
52
+ """
53
+
54
+ idea_reflection_prompt = """Round {current_round}/{num_reflections}.
55
+ In your thoughts, first carefully consider the quality, novelty, and feasibility of the idea you just created.
56
+ Include any other factors that you think are important in evaluating the idea.
57
+ Ensure the idea is clear and concise, and the JSON is the correct format.
58
+ Do not make things overly complicated.
59
+ In the next attempt, try and refine and improve your idea.
60
+ Stick to the spirit of the original idea unless there are glaring issues.
61
+
62
+ Respond in the same format as before:
63
+ THOUGHT:
64
+ <THOUGHT>
65
+
66
+ NEW IDEA JSON:
67
+ ```json
68
+ <JSON>
69
+ ```
70
+
71
+ If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
72
+ ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES."""
73
+
74
+
75
+ # GENERATE IDEAS
76
+ def generate_ideas(
77
+ base_dir,
78
+ client,
79
+ model,
80
+ skip_generation=False,
81
+ max_num_generations=20,
82
+ num_reflections=5,
83
+ ):
84
+ if skip_generation:
85
+ # Load existing ideas from file
86
+ try:
87
+ with open(osp.join(base_dir, "ideas.json"), "r") as f:
88
+ ideas = json.load(f)
89
+ print("Loaded existing ideas:")
90
+ for idea in ideas:
91
+ print(idea)
92
+ return ideas
93
+ except FileNotFoundError:
94
+ print("No existing ideas found. Generating new ideas.")
95
+ except json.JSONDecodeError:
96
+ print("Error decoding existing ideas. Generating new ideas.")
97
+
98
+ idea_str_archive = []
99
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
100
+ seed_ideas = json.load(f)
101
+ for seed_idea in seed_ideas:
102
+ idea_str_archive.append(json.dumps(seed_idea))
103
+
104
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
105
+ code = f.read()
106
+
107
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
108
+ prompt = json.load(f)
109
+
110
+ idea_system_prompt = prompt["system"]
111
+
112
+ for _ in range(max_num_generations):
113
+ print()
114
+ print(f"Generating idea {_ + 1}/{max_num_generations}")
115
+ try:
116
+ prev_ideas_string = "\n\n".join(idea_str_archive)
117
+
118
+ msg_history = []
119
+ print(f"Iteration 1/{num_reflections}")
120
+ text, msg_history = get_response_from_llm(
121
+ idea_first_prompt.format(
122
+ task_description=prompt["task_description"],
123
+ code=code,
124
+ prev_ideas_string=prev_ideas_string,
125
+ num_reflections=num_reflections,
126
+ ),
127
+ client=client,
128
+ model=model,
129
+ system_message=idea_system_prompt,
130
+ msg_history=msg_history,
131
+ )
132
+ ## PARSE OUTPUT
133
+ json_output = extract_json_between_markers(text)
134
+ assert json_output is not None, "Failed to extract JSON from LLM output"
135
+ print(json_output)
136
+
137
+ # Iteratively improve task.
138
+ if num_reflections > 1:
139
+ for j in range(num_reflections - 1):
140
+ print(f"Iteration {j + 2}/{num_reflections}")
141
+ text, msg_history = get_response_from_llm(
142
+ idea_reflection_prompt.format(
143
+ current_round=j + 2, num_reflections=num_reflections
144
+ ),
145
+ client=client,
146
+ model=model,
147
+ system_message=idea_system_prompt,
148
+ msg_history=msg_history,
149
+ )
150
+ ## PARSE OUTPUT
151
+ json_output = extract_json_between_markers(text)
152
+ assert (
153
+ json_output is not None
154
+ ), "Failed to extract JSON from LLM output"
155
+ print(json_output)
156
+
157
+ if "I am done" in text:
158
+ print(f"Idea generation converged after {j + 2} iterations.")
159
+ break
160
+
161
+ idea_str_archive.append(json.dumps(json_output))
162
+ except Exception as e:
163
+ print(f"Failed to generate idea: {e}")
164
+ continue
165
+
166
+ ## SAVE IDEAS
167
+ ideas = []
168
+ for idea_str in idea_str_archive:
169
+ ideas.append(json.loads(idea_str))
170
+
171
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
172
+ json.dump(ideas, f, indent=4)
173
+
174
+ return ideas
175
+
176
+
177
+ # GENERATE IDEAS OPEN-ENDED
178
+ def generate_next_idea(
179
+ base_dir,
180
+ client,
181
+ model,
182
+ prev_idea_archive=[],
183
+ num_reflections=5,
184
+ max_attempts=10,
185
+ ):
186
+ idea_archive = prev_idea_archive
187
+ original_archive_size = len(idea_archive)
188
+
189
+ print(f"Generating idea {original_archive_size + 1}")
190
+
191
+ if len(prev_idea_archive) == 0:
192
+ print(f"First iteration, taking seed ideas")
193
+ # seed the archive on the first run with pre-existing ideas
194
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
195
+ seed_ideas = json.load(f)
196
+ for seed_idea in seed_ideas[:1]:
197
+ idea_archive.append(seed_idea)
198
+ else:
199
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
200
+ code = f.read()
201
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
202
+ prompt = json.load(f)
203
+ idea_system_prompt = prompt["system"]
204
+
205
+ for _ in range(max_attempts):
206
+ try:
207
+ idea_strings = []
208
+ for idea in idea_archive:
209
+ idea_strings.append(json.dumps(idea))
210
+ prev_ideas_string = "\n\n".join(idea_strings)
211
+
212
+ msg_history = []
213
+ print(f"Iteration 1/{num_reflections}")
214
+ text, msg_history = get_response_from_llm(
215
+ idea_first_prompt.format(
216
+ task_description=prompt["task_description"],
217
+ code=code,
218
+ prev_ideas_string=prev_ideas_string,
219
+ num_reflections=num_reflections,
220
+ )
221
+ + """
222
+ Completed ideas have an additional "Score" field which indicates the assessment by an expert ML reviewer.
223
+ This is on a standard 1-10 ML conference scale.
224
+ Scores of 0 indicate the idea failed either during experimentation, writeup or reviewing.
225
+ """,
226
+ client=client,
227
+ model=model,
228
+ system_message=idea_system_prompt,
229
+ msg_history=msg_history,
230
+ )
231
+ ## PARSE OUTPUT
232
+ json_output = extract_json_between_markers(text)
233
+ assert json_output is not None, "Failed to extract JSON from LLM output"
234
+ print(json_output)
235
+
236
+ # Iteratively improve task.
237
+ if num_reflections > 1:
238
+ for j in range(num_reflections - 1):
239
+ print(f"Iteration {j + 2}/{num_reflections}")
240
+ text, msg_history = get_response_from_llm(
241
+ idea_reflection_prompt.format(
242
+ current_round=j + 2, num_reflections=num_reflections
243
+ ),
244
+ client=client,
245
+ model=model,
246
+ system_message=idea_system_prompt,
247
+ msg_history=msg_history,
248
+ )
249
+ ## PARSE OUTPUT
250
+ json_output = extract_json_between_markers(text)
251
+ assert (
252
+ json_output is not None
253
+ ), "Failed to extract JSON from LLM output"
254
+ print(json_output)
255
+
256
+ if "I am done" in text:
257
+ print(
258
+ f"Idea generation converged after {j + 2} iterations."
259
+ )
260
+ break
261
+
262
+ idea_archive.append(json_output)
263
+ break
264
+ except Exception as e:
265
+ print(f"Failed to generate idea: {e}")
266
+ continue
267
+
268
+ ## SAVE IDEAS
269
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
270
+ json.dump(idea_archive, f, indent=4)
271
+
272
+ return idea_archive
273
+
274
+
275
+ def on_backoff(details):
276
+ print(
277
+ f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
278
+ f"calling function {details['target'].__name__} at {time.strftime('%X')}"
279
+ )
280
+
281
+
282
+ @backoff.on_exception(
283
+ backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
284
+ )
285
+ def search_for_papers(query, result_limit=10, engine="semanticscholar") -> Union[None, List[Dict]]:
286
+ if not query:
287
+ return None
288
+ if engine == "semanticscholar":
289
+ rsp = requests.get(
290
+ "https://api.semanticscholar.org/graph/v1/paper/search",
291
+ headers={"X-API-KEY": S2_API_KEY} if S2_API_KEY else {},
292
+ params={
293
+ "query": query,
294
+ "limit": result_limit,
295
+ "fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
296
+ },
297
+ )
298
+ print(f"Response Status Code: {rsp.status_code}")
299
+ print(
300
+ f"Response Content: {rsp.text[:500]}"
301
+ ) # Print the first 500 characters of the response content
302
+ rsp.raise_for_status()
303
+ results = rsp.json()
304
+ total = results["total"]
305
+ time.sleep(1.0)
306
+ if not total:
307
+ return None
308
+
309
+ papers = results["data"]
310
+ return papers
311
+ elif engine == "openalex":
312
+ import pyalex
313
+ from pyalex import Work, Works
314
+ mail = os.environ.get("OPENALEX_MAIL_ADDRESS", None)
315
+ if mail is None:
316
+ print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better access to OpenAlex API!")
317
+ else:
318
+ pyalex.config.email = mail
319
+
320
+ def extract_info_from_work(work: Work, max_abstract_length: int = 1000) -> dict[str, str]:
321
+ # "Unknown" is returned when venue is unknown...
322
+ venue = "Unknown"
323
+ for i, location in enumerate(work["locations"]):
324
+ if location["source"] is not None:
325
+ venue = location["source"]["display_name"]
326
+ if venue != "":
327
+ break
328
+ title = work["title"]
329
+ abstract = work["abstract"]
330
+ if abstract is None:
331
+ abstract = ""
332
+ if len(abstract) > max_abstract_length:
333
+ # To avoid context length exceed error.
334
+ print(f"[WARNING] {title=}: {len(abstract)=} is too long! Use first {max_abstract_length} chars.")
335
+ abstract = abstract[:max_abstract_length]
336
+ authors_list = [author["author"]["display_name"] for author in work["authorships"]]
337
+ authors = " and ".join(authors_list) if len(authors_list) < 20 else f"{authors_list[0]} et al."
338
+ paper = dict(
339
+ title=title,
340
+ authors=authors,
341
+ venue=venue,
342
+ year=work["publication_year"],
343
+ abstract=abstract,
344
+ citationCount=work["cited_by_count"],
345
+ )
346
+ return paper
347
+
348
+ works: List[Dict] = Works().search(query).get(per_page=result_limit)
349
+ papers: List[Dict[str, str]] = [extract_info_from_work(work) for work in works]
350
+ return papers
351
+ else:
352
+ raise NotImplementedError(f"{engine=} not supported!")
353
+
354
+
355
+
356
+ novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
357
+ You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
358
+ Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
359
+ You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
360
+ The top 10 results for any search query will be presented to you with the abstracts.
361
+
362
+ You will be given {num_rounds} rounds to decide on the paper, but you do not need to use them all.
363
+ At any round, you may exit early and decide on the novelty of the idea.
364
+ Decide a paper idea is novel if, after sufficient searching, you have not found a paper that significantly overlaps with your idea.
365
+ Decide a paper idea is not novel if you have found a paper that significantly overlaps with your idea.
366
+
367
+ {task_description}
368
+ <experiment.py>
369
+ {code}
370
+ </experiment.py>
371
+ """
372
+
373
+ novelty_prompt = '''Round {current_round}/{num_rounds}.
374
+ You have this idea:
375
+
376
+ """
377
+ {idea}
378
+ """
379
+
380
+ The results of the last query are (empty on first round):
381
+ """
382
+ {last_query_results}
383
+ """
384
+
385
+ Respond in the following format:
386
+
387
+ THOUGHT:
388
+ <THOUGHT>
389
+
390
+ RESPONSE:
391
+ ```json
392
+ <JSON>
393
+ ```
394
+
395
+ In <THOUGHT>, first briefly reason over the idea and identify any query that could help you make your decision.
396
+ If you have made your decision, add "Decision made: novel." or "Decision made: not novel." to your thoughts.
397
+
398
+ In <JSON>, respond in JSON format with ONLY the following field:
399
+ - "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
400
+
401
+ A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
402
+ This JSON will be automatically parsed, so ensure the format is precise.'''
403
+
404
+
405
+ def check_idea_novelty(
406
+ ideas,
407
+ base_dir,
408
+ client,
409
+ model,
410
+ max_num_iterations=10,
411
+ engine="semanticscholar",
412
+ ):
413
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
414
+ code = f.read()
415
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
416
+ prompt = json.load(f)
417
+ task_description = prompt["task_description"]
418
+
419
+ for idx, idea in enumerate(ideas):
420
+ if "novel" in idea:
421
+ print(f"Skipping idea {idx}, already checked.")
422
+ continue
423
+
424
+ print(f"\nChecking novelty of idea {idx}: {idea['Name']}")
425
+
426
+ novel = False
427
+ msg_history = []
428
+ papers_str = ""
429
+
430
+ for j in range(max_num_iterations):
431
+ try:
432
+ text, msg_history = get_response_from_llm(
433
+ novelty_prompt.format(
434
+ current_round=j + 1,
435
+ num_rounds=max_num_iterations,
436
+ idea=idea,
437
+ last_query_results=papers_str,
438
+ ),
439
+ client=client,
440
+ model=model,
441
+ system_message=novelty_system_msg.format(
442
+ num_rounds=max_num_iterations,
443
+ task_description=task_description,
444
+ code=code,
445
+ ),
446
+ msg_history=msg_history,
447
+ )
448
+ if "decision made: novel" in text.lower():
449
+ print("Decision made: novel after round", j)
450
+ novel = True
451
+ break
452
+ if "decision made: not novel" in text.lower():
453
+ print("Decision made: not novel after round", j)
454
+ break
455
+
456
+ ## PARSE OUTPUT
457
+ json_output = extract_json_between_markers(text)
458
+ assert json_output is not None, "Failed to extract JSON from LLM output"
459
+
460
+ ## SEARCH FOR PAPERS
461
+ query = json_output["Query"]
462
+ papers = search_for_papers(query, result_limit=10, engine=engine)
463
+ if papers is None:
464
+ papers_str = "No papers found."
465
+
466
+ paper_strings = []
467
+ for i, paper in enumerate(papers):
468
+ paper_strings.append(
469
+ """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
470
+ i=i,
471
+ title=paper["title"],
472
+ authors=paper["authors"],
473
+ venue=paper["venue"],
474
+ year=paper["year"],
475
+ cites=paper["citationCount"],
476
+ abstract=paper["abstract"],
477
+ )
478
+ )
479
+ papers_str = "\n\n".join(paper_strings)
480
+
481
+ except Exception as e:
482
+ print(f"Error: {e}")
483
+ continue
484
+
485
+ idea["novel"] = novel
486
+
487
+ # Save results to JSON file
488
+ results_file = osp.join(base_dir, "ideas.json")
489
+ with open(results_file, "w") as f:
490
+ json.dump(ideas, f, indent=4)
491
+
492
+ return ideas
493
+
494
+
495
+ if __name__ == "__main__":
496
+ MAX_NUM_GENERATIONS = 32
497
+ NUM_REFLECTIONS = 5
498
+ import argparse
499
+
500
+ parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
501
+ # add type of experiment (nanoGPT, Boston, etc.)
502
+ parser.add_argument(
503
+ "--experiment",
504
+ type=str,
505
+ default="nanoGPT",
506
+ help="Experiment to run AI Scientist on.",
507
+ )
508
+ parser.add_argument(
509
+ "--model",
510
+ type=str,
511
+ default="gpt-4o-2024-05-13",
512
+ choices=AVAILABLE_LLMS,
513
+ help="Model to use for AI Scientist.",
514
+ )
515
+ parser.add_argument(
516
+ "--skip-idea-generation",
517
+ action="store_true",
518
+ help="Skip idea generation and use existing ideas.",
519
+ )
520
+ parser.add_argument(
521
+ "--check-novelty",
522
+ action="store_true",
523
+ help="Check novelty of ideas.",
524
+ )
525
+ args = parser.parse_args()
526
+
527
+ # Create client
528
+ client, client_model = create_client(args.model)
529
+
530
+ base_dir = osp.join("templates", args.experiment)
531
+ results_dir = osp.join("results", args.experiment)
532
+ ideas = generate_ideas(
533
+ base_dir,
534
+ client=client,
535
+ model=client_model,
536
+ skip_generation=args.skip_idea_generation,
537
+ max_num_generations=MAX_NUM_GENERATIONS,
538
+ num_reflections=NUM_REFLECTIONS,
539
+ )
540
+ if args.check_novelty:
541
+ ideas = check_idea_novelty(
542
+ ideas,
543
+ base_dir=base_dir,
544
+ client=client,
545
+ model=client_model,
546
+ )
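
For anyone adapting this module to a new template, the sketch below illustrates the two JSON files that `generate_ideas` reads from `base_dir` (an `experiment.py` must also be present there). The directory name and all field values are illustrative assumptions; only the keys mirror what the code above actually accesses.

```python
# A minimal sketch of the template files generate_ideas() expects.
# The path and the concrete values are hypothetical; only the keys
# ("system", "task_description", and the idea fields listed in
# idea_first_prompt) come from the code above.
import json
import os

base_dir = "templates/my_template"  # hypothetical template directory
os.makedirs(base_dir, exist_ok=True)

prompt = {
    "system": "You are an ambitious AI researcher proposing experiments.",
    "task_description": "Propose experiments that build on the provided training script.",
}
seed_ideas = [
    {
        "Name": "baseline_sweep",
        "Title": "A Simple Baseline Hyperparameter Sweep",
        "Experiment": "Sweep the learning rate in experiment.py and compare final losses.",
        "Interestingness": 4,
        "Feasibility": 9,
        "Novelty": 2,
    }
]

with open(os.path.join(base_dir, "prompt.json"), "w") as f:
    json.dump(prompt, f, indent=4)
with open(os.path.join(base_dir, "seed_ideas.json"), "w") as f:
    json.dump(seed_ideas, f, indent=4)
```
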
ai_scientist/llm.py ADDED
@@ -0,0 +1,351 @@
1
+ import json
2
+ import os
3
+ import re
4
+
5
+ import anthropic
6
+ import backoff
7
+ import openai
8
+ import google.generativeai as genai
9
+ from google.generativeai.types import GenerationConfig
10
+
11
+ MAX_NUM_TOKENS = 4096
12
+
13
+ AVAILABLE_LLMS = [
14
+ # Anthropic models
15
+ "claude-3-5-sonnet-20240620",
16
+ "claude-3-5-sonnet-20241022",
17
+ # OpenAI models
18
+ "gpt-4o-mini",
19
+ "gpt-4o-mini-2024-07-18",
20
+ "gpt-4o",
21
+ "gpt-4o-2024-05-13",
22
+ "gpt-4o-2024-08-06",
23
+ "gpt-4.1",
24
+ "gpt-4.1-2025-04-14",
25
+ "gpt-4.1-mini",
26
+ "gpt-4.1-mini-2025-04-14",
27
+ "gpt-4.1-nano",
28
+ "gpt-4.1-nano-2025-04-14",
29
+ "o1",
30
+ "o1-2024-12-17",
31
+ "o1-preview-2024-09-12",
32
+ "o1-mini",
33
+ "o1-mini-2024-09-12",
34
+ "o3-mini",
35
+ "o3-mini-2025-01-31",
36
+ # OpenRouter models
37
+ "llama3.1-405b",
38
+ # Anthropic Claude models via Amazon Bedrock
39
+ "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
40
+ "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
41
+ "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
42
+ "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
43
+ "bedrock/anthropic.claude-3-opus-20240229-v1:0",
44
+ # Anthropic Claude models Vertex AI
45
+ "vertex_ai/claude-3-opus@20240229",
46
+ "vertex_ai/claude-3-5-sonnet@20240620",
47
+ "vertex_ai/claude-3-5-sonnet-v2@20241022",
48
+ "vertex_ai/claude-3-sonnet@20240229",
49
+ "vertex_ai/claude-3-haiku@20240307",
50
+ # DeepSeek models
51
+ "deepseek-chat",
52
+ "deepseek-coder",
53
+ "deepseek-reasoner",
54
+ # Google Gemini models
55
+ "gemini-1.5-flash",
56
+ "gemini-1.5-pro",
57
+ "gemini-2.0-flash",
58
+ "gemini-2.0-flash-lite",
59
+ "gemini-2.0-flash-thinking-exp-01-21",
60
+ "gemini-2.5-pro-preview-03-25",
61
+ "gemini-2.5-pro-exp-03-25",
62
+ ]
63
+
64
+
65
+ # Get N responses from a single message, used for ensembling.
66
+ @backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
67
+ def get_batch_responses_from_llm(
68
+ msg,
69
+ client,
70
+ model,
71
+ system_message,
72
+ print_debug=False,
73
+ msg_history=None,
74
+ temperature=0.75,
75
+ n_responses=1,
76
+ ):
77
+ if msg_history is None:
78
+ msg_history = []
79
+
80
+ if 'gpt' in model:
81
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
82
+ response = client.chat.completions.create(
83
+ model=model,
84
+ messages=[
85
+ {"role": "system", "content": system_message},
86
+ *new_msg_history,
87
+ ],
88
+ temperature=temperature,
89
+ max_tokens=MAX_NUM_TOKENS,
90
+ n=n_responses,
91
+ stop=None,
92
+ seed=0,
93
+ )
94
+ content = [r.message.content for r in response.choices]
95
+ new_msg_history = [
96
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
97
+ ]
98
+ elif model == "llama-3-1-405b-instruct":
99
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
100
+ response = client.chat.completions.create(
101
+ model="meta-llama/llama-3.1-405b-instruct",
102
+ messages=[
103
+ {"role": "system", "content": system_message},
104
+ *new_msg_history,
105
+ ],
106
+ temperature=temperature,
107
+ max_tokens=MAX_NUM_TOKENS,
108
+ n=n_responses,
109
+ stop=None,
110
+ )
111
+ content = [r.message.content for r in response.choices]
112
+ new_msg_history = [
113
+ new_msg_history + [{"role": "assistant", "content": c}] for c in content
114
+ ]
115
+ else:
116
+ content, new_msg_history = [], []
117
+ for _ in range(n_responses):
118
+ c, hist = get_response_from_llm(
119
+ msg,
120
+ client,
121
+ model,
122
+ system_message,
123
+ print_debug=False,
124
+ msg_history=None,
125
+ temperature=temperature,
126
+ )
127
+ content.append(c)
128
+ new_msg_history.append(hist)
129
+
130
+ if print_debug:
131
+ print()
132
+ print("*" * 20 + " LLM START " + "*" * 20)
133
+ for j, msg in enumerate(new_msg_history[0]):
134
+ print(f'{j}, {msg["role"]}: {msg["content"]}')
135
+ print(content)
136
+ print("*" * 21 + " LLM END " + "*" * 21)
137
+ print()
138
+
139
+ return content, new_msg_history
140
+
141
+
142
+ @backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
143
+ def get_response_from_llm(
144
+ msg,
145
+ client,
146
+ model,
147
+ system_message,
148
+ print_debug=False,
149
+ msg_history=None,
150
+ temperature=0.75,
151
+ ):
152
+ if msg_history is None:
153
+ msg_history = []
154
+
155
+ if "claude" in model:
156
+ new_msg_history = msg_history + [
157
+ {
158
+ "role": "user",
159
+ "content": [
160
+ {
161
+ "type": "text",
162
+ "text": msg,
163
+ }
164
+ ],
165
+ }
166
+ ]
167
+ response = client.messages.create(
168
+ model=model,
169
+ max_tokens=MAX_NUM_TOKENS,
170
+ temperature=temperature,
171
+ system=system_message,
172
+ messages=new_msg_history,
173
+ )
174
+ content = response.content[0].text
175
+ new_msg_history = new_msg_history + [
176
+ {
177
+ "role": "assistant",
178
+ "content": [
179
+ {
180
+ "type": "text",
181
+ "text": content,
182
+ }
183
+ ],
184
+ }
185
+ ]
186
+ elif 'gpt' in model:
187
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
188
+ response = client.chat.completions.create(
189
+ model=model,
190
+ messages=[
191
+ {"role": "system", "content": system_message},
192
+ *new_msg_history,
193
+ ],
194
+ temperature=temperature,
195
+ max_tokens=MAX_NUM_TOKENS,
196
+ n=1,
197
+ stop=None,
198
+ seed=0,
199
+ )
200
+ content = response.choices[0].message.content
201
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
202
+ elif "o1" in model or "o3" in model:
203
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
204
+ response = client.chat.completions.create(
205
+ model=model,
206
+ messages=[
207
+ {"role": "user", "content": system_message},
208
+ *new_msg_history,
209
+ ],
210
+ temperature=1,
211
+ max_completion_tokens=MAX_NUM_TOKENS,
212
+ n=1,
213
+ seed=0,
214
+ )
215
+ content = response.choices[0].message.content
216
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
217
+ elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
218
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
219
+ response = client.chat.completions.create(
220
+ model="meta-llama/llama-3.1-405b-instruct",
221
+ messages=[
222
+ {"role": "system", "content": system_message},
223
+ *new_msg_history,
224
+ ],
225
+ temperature=temperature,
226
+ max_tokens=MAX_NUM_TOKENS,
227
+ n=1,
228
+ stop=None,
229
+ )
230
+ content = response.choices[0].message.content
231
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
232
+ elif model in ["deepseek-chat", "deepseek-coder"]:
233
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
234
+ response = client.chat.completions.create(
235
+ model=model,
236
+ messages=[
237
+ {"role": "system", "content": system_message},
238
+ *new_msg_history,
239
+ ],
240
+ temperature=temperature,
241
+ max_tokens=MAX_NUM_TOKENS,
242
+ n=1,
243
+ stop=None,
244
+ )
245
+ content = response.choices[0].message.content
246
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
247
+ elif model in ["deepseek-reasoner"]:
248
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
249
+ response = client.chat.completions.create(
250
+ model=model,
251
+ messages=[
252
+ {"role": "system", "content": system_message},
253
+ *new_msg_history,
254
+ ],
255
+ n=1,
256
+ stop=None,
257
+ )
258
+ content = response.choices[0].message.content
259
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
260
+ elif "gemini" in model:
261
+ new_msg_history = msg_history + [{"role": "user", "content": msg}]
262
+ response = client.chat.completions.create(
263
+ model=model,
264
+ messages=[
265
+ {"role": "system", "content": system_message},
266
+ *new_msg_history,
267
+ ],
268
+ temperature=temperature,
269
+ max_tokens=MAX_NUM_TOKENS,
270
+ n=1,
271
+ )
272
+ content = response.choices[0].message.content
273
+ new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
274
+ else:
275
+ raise ValueError(f"Model {model} not supported.")
276
+
277
+ if print_debug:
278
+ print()
279
+ print("*" * 20 + " LLM START " + "*" * 20)
280
+ for j, msg in enumerate(new_msg_history):
281
+ print(f'{j}, {msg["role"]}: {msg["content"]}')
282
+ print(content)
283
+ print("*" * 21 + " LLM END " + "*" * 21)
284
+ print()
285
+
286
+ return content, new_msg_history
287
+
288
+
289
+ def extract_json_between_markers(llm_output):
290
+ # Regular expression pattern to find JSON content between ```json and ```
291
+ json_pattern = r"```json(.*?)```"
292
+ matches = re.findall(json_pattern, llm_output, re.DOTALL)
293
+
294
+ if not matches:
295
+ # Fallback: Try to find any JSON-like content in the output
296
+ json_pattern = r"\{.*?\}"
297
+ matches = re.findall(json_pattern, llm_output, re.DOTALL)
298
+
299
+ for json_string in matches:
300
+ json_string = json_string.strip()
301
+ try:
302
+ parsed_json = json.loads(json_string)
303
+ return parsed_json
304
+ except json.JSONDecodeError:
305
+ # Attempt to fix common JSON issues
306
+ try:
307
+ # Remove invalid control characters
308
+ json_string_clean = re.sub(r"[\x00-\x1F\x7F]", "", json_string)
309
+ parsed_json = json.loads(json_string_clean)
310
+ return parsed_json
311
+ except json.JSONDecodeError:
312
+ continue # Try next match
313
+
314
+ return None # No valid JSON found
315
+
316
+
317
+ def create_client(model):
318
+ if model.startswith("claude-"):
319
+ print(f"Using Anthropic API with model {model}.")
320
+ return anthropic.Anthropic(), model
321
+ elif model.startswith("bedrock") and "claude" in model:
322
+ client_model = model.split("/")[-1]
323
+ print(f"Using Amazon Bedrock with model {client_model}.")
324
+ return anthropic.AnthropicBedrock(), client_model
325
+ elif model.startswith("vertex_ai") and "claude" in model:
326
+ client_model = model.split("/")[-1]
327
+ print(f"Using Vertex AI with model {client_model}.")
328
+ return anthropic.AnthropicVertex(), client_model
329
+ elif 'gpt' in model or "o1" in model or "o3" in model:
330
+ print(f"Using OpenAI API with model {model}.")
331
+ return openai.OpenAI(), model
332
+ elif model in ["deepseek-chat", "deepseek-reasoner", "deepseek-coder"]:
333
+ print(f"Using OpenAI API with {model}.")
334
+ return openai.OpenAI(
335
+ api_key=os.environ["DEEPSEEK_API_KEY"],
336
+ base_url="https://api.deepseek.com"
337
+ ), model
338
+ elif model == "llama3.1-405b":
339
+ print(f"Using OpenAI API with {model}.")
340
+ return openai.OpenAI(
341
+ api_key=os.environ["OPENROUTER_API_KEY"],
342
+ base_url="https://openrouter.ai/api/v1"
343
+ ), "meta-llama/llama-3.1-405b-instruct"
344
+ elif "gemini" in model:
345
+ print(f"Using OpenAI API with {model}.")
346
+ return openai.OpenAI(
347
+ api_key=os.environ["GEMINI_API_KEY"],
348
+ base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
349
+ ), model
350
+ else:
351
+ raise ValueError(f"Model {model} not supported.")
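
A rough usage sketch for this module, assuming the relevant API key (here `OPENAI_API_KEY`) is set and the chosen model is available to your account; the prompt text and model name are placeholders:

```python
# Illustrative only: the model name and prompt are placeholders, and an
# OPENAI_API_KEY is assumed to be configured in the environment.
from ai_scientist.llm import create_client, get_response_from_llm, extract_json_between_markers

client, client_model = create_client("gpt-4o-2024-05-13")

text, msg_history = get_response_from_llm(
    'Reply with a JSON object containing a single key "ok" set to true, inside a json code fence.',
    client=client,
    model=client_model,
    system_message="You are a helpful assistant.",
    msg_history=[],
)
parsed = extract_json_between_markers(text)  # dict on success, None if no valid JSON was found
print(parsed)
```
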
ai_scientist/perform_experiments.py ADDED
@@ -0,0 +1,166 @@
1
+ import json
2
+ import os.path as osp
3
+ import shutil
4
+ import subprocess
5
+ import sys
6
+ from subprocess import TimeoutExpired
7
+
8
+ MAX_ITERS = 4
9
+ MAX_RUNS = 5
10
+ MAX_STDERR_OUTPUT = 1500
11
+
12
+ coder_prompt = """Your goal is to implement the following idea: {title}.
13
+ The proposed experiment is as follows: {idea}.
14
+ You are given a total of up to {max_runs} runs to complete the necessary experiments. You do not need to use all {max_runs}.
15
+
16
+ First, plan the list of experiments you would like to run. For example, if you are sweeping over a specific hyperparameter, plan each value you would like to test for each run.
17
+
18
+ Note that we already provide the vanilla baseline results, so you do not need to re-run it.
19
+
20
+ For reference, the baseline results are as follows:
21
+
22
+ {baseline_results}
23
+
24
+ After you complete each change, we will run the command `python experiment.py --out_dir=run_i` where i is the run number and evaluate the results.
25
+ YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
26
+ You can then implement the next thing on your list."""
27
+
28
+
29
+ # RUN EXPERIMENT
30
+ def run_experiment(folder_name, run_num, timeout=7200):
31
+ cwd = osp.abspath(folder_name)
32
+ # COPY CODE SO WE CAN SEE IT.
33
+ shutil.copy(
34
+ osp.join(folder_name, "experiment.py"),
35
+ osp.join(folder_name, f"run_{run_num}.py"),
36
+ )
37
+
38
+ # LAUNCH COMMAND
39
+ command = [
40
+ "python",
41
+ "experiment.py",
42
+ f"--out_dir=run_{run_num}",
43
+ ]
44
+ try:
45
+ result = subprocess.run(
46
+ command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
47
+ )
48
+
49
+ if result.stderr:
50
+ print(result.stderr, file=sys.stderr)
51
+
52
+ if result.returncode != 0:
53
+ print(f"Run {run_num} failed with return code {result.returncode}")
54
+ if osp.exists(osp.join(cwd, f"run_{run_num}")):
55
+ shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
56
+ print(f"Run failed with the following error {result.stderr}")
57
+ stderr_output = result.stderr
58
+ if len(stderr_output) > MAX_STDERR_OUTPUT:
59
+ stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
60
+ next_prompt = f"Run failed with the following error {stderr_output}"
61
+ else:
62
+ with open(osp.join(cwd, f"run_{run_num}", "final_info.json"), "r") as f:
63
+ results = json.load(f)
64
+ results = {k: v["means"] for k, v in results.items()}
65
+
66
+ next_prompt = f"""Run {run_num} completed. Here are the results:
67
+ {results}
68
+
69
+ Decide if you need to re-plan your experiments given the result (you often will not need to).
70
+
71
+ Someone else will be using `notes.txt` to perform a writeup on this in the future.
72
+ Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.
73
+
74
+ Then, implement the next thing on your list.
75
+ We will then run the command `python experiment.py --out_dir=run_{run_num + 1}`.
76
+ YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
77
+ If you are finished with experiments, respond with 'ALL_COMPLETED'."""
78
+ return result.returncode, next_prompt
79
+ except TimeoutExpired:
80
+ print(f"Run {run_num} timed out after {timeout} seconds")
81
+ if osp.exists(osp.join(cwd, f"run_{run_num}")):
82
+ shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
83
+ next_prompt = f"Run timed out after {timeout} seconds"
84
+ return 1, next_prompt
85
+
86
+
87
+ # RUN PLOTTING
88
+ def run_plotting(folder_name, timeout=600):
89
+ cwd = osp.abspath(folder_name)
90
+ # LAUNCH COMMAND
91
+ command = [
92
+ "python",
93
+ "plot.py",
94
+ ]
95
+ try:
96
+ result = subprocess.run(
97
+ command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
98
+ )
99
+
100
+ if result.stderr:
101
+ print(result.stderr, file=sys.stderr)
102
+
103
+ if result.returncode != 0:
104
+ print(f"Plotting failed with return code {result.returncode}")
105
+ next_prompt = f"Plotting failed with the following error {result.stderr}"
106
+ else:
107
+ next_prompt = ""
108
+ return result.returncode, next_prompt
109
+ except TimeoutExpired:
110
+ print(f"Plotting timed out after {timeout} seconds")
111
+ next_prompt = f"Plotting timed out after {timeout} seconds"
112
+ return 1, next_prompt
113
+
114
+
115
+ # PERFORM EXPERIMENTS
116
+ def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
117
+ ## RUN EXPERIMENT
118
+ current_iter = 0
119
+ run = 1
120
+ next_prompt = coder_prompt.format(
121
+ title=idea["Title"],
122
+ idea=idea["Experiment"],
123
+ max_runs=MAX_RUNS,
124
+ baseline_results=baseline_results,
125
+ )
126
+ while run < MAX_RUNS + 1:
127
+ if current_iter >= MAX_ITERS:
128
+ print("Max iterations reached")
129
+ break
130
+ coder_out = coder.run(next_prompt)
131
+ print(coder_out)
132
+ if "ALL_COMPLETED" in coder_out:
133
+ break
134
+ return_code, next_prompt = run_experiment(folder_name, run)
135
+ if return_code == 0:
136
+ run += 1
137
+ current_iter = 0
138
+ current_iter += 1
139
+ if current_iter >= MAX_ITERS:
140
+ print("Not all experiments completed.")
141
+ return False
142
+
143
+ current_iter = 0
144
+ next_prompt = """
145
+ Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.
146
+
147
+ In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.
148
+
149
+ Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.
150
+
151
+ We will be running the command `python plot.py` to generate the plots.
152
+ """
153
+ while True:
154
+ _ = coder.run(next_prompt)
155
+ return_code, next_prompt = run_plotting(folder_name)
156
+ current_iter += 1
157
+ if return_code == 0 or current_iter >= MAX_ITERS:
158
+ break
159
+ next_prompt = """
160
+ Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.
161
+
162
+ Somebody else will be using `notes.txt` to write a report on this in the future.
163
+ """
164
+ coder.run(next_prompt)
165
+
166
+ return True
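
Note that `perform_experiments` only assumes the `coder` object exposes a `run(prompt) -> str` method; in the full pipeline this is an aider coder, but any stand-in with that interface works. A rough sketch under that assumption (the folder and idea below are hypothetical, and the folder must already contain `experiment.py`, `plot.py`, and `notes.txt`):

```python
# Illustrative stand-in for the aider coder; it only demonstrates the
# run(prompt) contract that perform_experiments() relies on.
from ai_scientist.perform_experiments import perform_experiments

class EchoCoder:
    def run(self, prompt: str) -> str:
        print(prompt[:200])        # show the start of what the coder is asked to do
        return "ALL_COMPLETED"     # immediately signal that no code changes are needed

idea = {"Title": "Baseline sweep", "Experiment": "Sweep the learning rate."}  # hypothetical
success = perform_experiments(idea, "templates/my_template", EchoCoder(), baseline_results={})
print("finished:", success)
```
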
ai_scientist/perform_review.py ADDED
@@ -0,0 +1,395 @@
1
+ import os
2
+ import numpy as np
3
+ import json
4
+ from pypdf import PdfReader
5
+ import pymupdf
6
+ import pymupdf4llm
7
+ from ai_scientist.llm import (
8
+ get_response_from_llm,
9
+ get_batch_responses_from_llm,
10
+ extract_json_between_markers,
11
+ )
12
+
13
+ reviewer_system_prompt_base = (
14
+ "You are an AI researcher who is reviewing a paper that was submitted to a prestigious ML venue."
15
+ " Be critical and cautious in your decision."
16
+ )
17
+
18
+ reviewer_system_prompt_neg = (
19
+ reviewer_system_prompt_base
20
+ + " If a paper is bad or you are unsure, give it bad scores and reject it."
21
+ )
22
+ reviewer_system_prompt_pos = (
23
+ reviewer_system_prompt_base
24
+ + " If a paper is good or you are unsure, give it good scores and accept it."
25
+ )
26
+
27
+ template_instructions = """
28
+ Respond in the following format:
29
+
30
+ THOUGHT:
31
+ <THOUGHT>
32
+
33
+ REVIEW JSON:
34
+ ```json
35
+ <JSON>
36
+ ```
37
+
38
+ In <THOUGHT>, first briefly discuss your intuitions and reasoning for the evaluation.
39
+ Detail your high-level arguments, necessary choices and desired outcomes of the review.
40
+ Do not make generic comments here, but be specific to your current paper.
41
+ Treat this as the note-taking phase of your review.
42
+
43
+ In <JSON>, provide the review in JSON format with the following fields in the order:
44
+ - "Summary": A summary of the paper content and its contributions.
45
+ - "Strengths": A list of strengths of the paper.
46
+ - "Weaknesses": A list of weaknesses of the paper.
47
+ - "Originality": A rating from 1 to 4 (low, medium, high, very high).
48
+ - "Quality": A rating from 1 to 4 (low, medium, high, very high).
49
+ - "Clarity": A rating from 1 to 4 (low, medium, high, very high).
50
+ - "Significance": A rating from 1 to 4 (low, medium, high, very high).
51
+ - "Questions": A set of clarifying questions to be answered by the paper authors.
52
+ - "Limitations": A set of limitations and potential negative societal impacts of the work.
53
+ - "Ethical Concerns": A boolean value indicating whether there are ethical concerns.
54
+ - "Soundness": A rating from 1 to 4 (poor, fair, good, excellent).
55
+ - "Presentation": A rating from 1 to 4 (poor, fair, good, excellent).
56
+ - "Contribution": A rating from 1 to 4 (poor, fair, good, excellent).
57
+ - "Overall": A rating from 1 to 10 (very strong reject to award quality).
58
+ - "Confidence": A rating from 1 to 5 (low, medium, high, very high, absolute).
59
+ - "Decision": A decision that has to be one of the following: Accept, Reject.
60
+
61
+ For the "Decision" field, don't use Weak Accept, Borderline Accept, Borderline Reject, or Strong Reject. Instead, only use Accept or Reject.
62
+ This JSON will be automatically parsed, so ensure the format is precise.
63
+ """
64
+
65
+ neurips_form = (
66
+ """
67
+ ## Review Form
68
+ Below is a description of the questions you will be asked on the review form for each paper and some guidelines on what to consider when answering these questions.
69
+ When writing your review, please keep in mind that after decisions have been made, reviews and meta-reviews of accepted papers and opted-in rejected papers will be made public.
70
+
71
+ 1. Summary: Briefly summarize the paper and its contributions. This is not the place to critique the paper; the authors should generally agree with a well-written summary.
72
+ - Strengths and Weaknesses: Please provide a thorough assessment of the strengths and weaknesses of the paper, touching on each of the following dimensions:
73
+ - Originality: Are the tasks or methods new? Is the work a novel combination of well-known techniques? (This can be valuable!) Is it clear how this work differs from previous contributions? Is related work adequately cited?
74
+ - Quality: Is the submission technically sound? Are claims well supported (e.g., by theoretical analysis or experimental results)? Are the methods used appropriate? Is this a complete piece of work or work in progress? Are the authors careful and honest about evaluating both the strengths and weaknesses of their work?
75
+ - Clarity: Is the submission clearly written? Is it well organized? (If not, please make constructive suggestions for improving its clarity.) Does it adequately inform the reader? (Note that a superbly written paper provides enough information for an expert reader to reproduce its results.)
76
+ - Significance: Are the results important? Are others (researchers or practitioners) likely to use the ideas or build on them? Does the submission address a difficult task in a better way than previous work? Does it advance the state of the art in a demonstrable way? Does it provide unique data, unique conclusions about existing data, or a unique theoretical or experimental approach?
77
+
78
+ 2. Questions: Please list up and carefully describe any questions and suggestions for the authors. Think of the things where a response from the author can change your opinion, clarify a confusion or address a limitation. This can be very important for a productive rebuttal and discussion phase with the authors.
79
+
80
+ 3. Limitations: Have the authors adequately addressed the limitations and potential negative societal impact of their work? If not, please include constructive suggestions for improvement.
81
+ In general, authors should be rewarded rather than punished for being up front about the limitations of their work and any potential negative societal impact. You are encouraged to think through whether any critical points are missing and provide these as feedback for the authors.
82
+
83
+ 4. Ethical concerns: If there are ethical issues with this paper, please flag the paper for an ethics review. For guidance on when this is appropriate, please review the NeurIPS ethics guidelines.
84
+
85
+ 5. Soundness: Please assign the paper a numerical rating on the following scale to indicate the soundness of the technical claims, experimental and research methodology and on whether the central claims of the paper are adequately supported with evidence.
86
+ 4: excellent
87
+ 3: good
88
+ 2: fair
89
+ 1: poor
90
+
91
+ 6. Presentation: Please assign the paper a numerical rating on the following scale to indicate the quality of the presentation. This should take into account the writing style and clarity, as well as contextualization relative to prior work.
92
+ 4: excellent
93
+ 3: good
94
+ 2: fair
95
+ 1: poor
96
+
97
+ 7. Contribution: Please assign the paper a numerical rating on the following scale to indicate the quality of the overall contribution this paper makes to the research area being studied. Are the questions being asked important? Does the paper bring a significant originality of ideas and/or execution? Are the results valuable to share with the broader NeurIPS community?
98
+ 4: excellent
99
+ 3: good
100
+ 2: fair
101
+ 1: poor
102
+
103
+ 8. Overall: Please provide an "overall score" for this submission. Choices:
104
+ 10: Award quality: Technically flawless paper with groundbreaking impact on one or more areas of AI, with exceptionally strong evaluation, reproducibility, and resources, and no unaddressed ethical considerations.
105
+ 9: Very Strong Accept: Technically flawless paper with groundbreaking impact on at least one area of AI and excellent impact on multiple areas of AI, with flawless evaluation, resources, and reproducibility, and no unaddressed ethical considerations.
106
+ 8: Strong Accept: Technically strong paper with novel ideas, excellent impact on at least one area of AI or high-to-excellent impact on multiple areas of AI, with excellent evaluation, resources, and reproducibility, and no unaddressed ethical considerations.
107
+ 7: Accept: Technically solid paper, with high impact on at least one sub-area of AI or moderate-to-high impact on more than one area of AI, with good-to-excellent evaluation, resources, reproducibility, and no unaddressed ethical considerations.
108
+ 6: Weak Accept: Technically solid, moderate-to-high impact paper, with no major concerns with respect to evaluation, resources, reproducibility, ethical considerations.
109
+ 5: Borderline accept: Technically solid paper where reasons to accept outweigh reasons to reject, e.g., limited evaluation. Please use sparingly.
110
+ 4: Borderline reject: Technically solid paper where reasons to reject, e.g., limited evaluation, outweigh reasons to accept, e.g., good evaluation. Please use sparingly.
111
+ 3: Reject: For instance, a paper with technical flaws, weak evaluation, inadequate reproducibility and incompletely addressed ethical considerations.
112
+ 2: Strong Reject: For instance, a paper with major technical flaws, and/or poor evaluation, limited impact, poor reproducibility and mostly unaddressed ethical considerations.
113
+ 1: Very Strong Reject: For instance, a paper with trivial results or unaddressed ethical considerations
114
+
115
+ 9. Confidence: Please provide a "confidence score" for your assessment of this submission to indicate how confident you are in your evaluation. Choices:
116
+ 5: You are absolutely certain about your assessment. You are very familiar with the related work and checked the math/other details carefully.
117
+ 4: You are confident in your assessment, but not absolutely certain. It is unlikely, but not impossible, that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work.
118
+ 3: You are fairly confident in your assessment. It is possible that you did not understand some parts of the submission or that you are unfamiliar with some pieces of related work. Math/other details were not carefully checked.
119
+ 2: You are willing to defend your assessment, but it is quite likely that you did not understand the central parts of the submission or that you are unfamiliar with some pieces of related work. Math/other details were not carefully checked.
120
+ 1: Your assessment is an educated guess. The submission is not in your area or the submission was difficult to understand. Math/other details were not carefully checked.
121
+ """
122
+ + template_instructions
123
+ )
124
+
125
+
126
+ def perform_review(
127
+ text,
128
+ model,
129
+ client,
130
+ num_reflections=1,
131
+ num_fs_examples=1,
132
+ num_reviews_ensemble=1,
133
+ temperature=0.75,
134
+ msg_history=None,
135
+ return_msg_history=False,
136
+ reviewer_system_prompt=reviewer_system_prompt_neg,
137
+ review_instruction_form=neurips_form,
138
+ ):
139
+ if num_fs_examples > 0:
140
+ fs_prompt = get_review_fewshot_examples(num_fs_examples)
141
+ base_prompt = review_instruction_form + fs_prompt
142
+ else:
143
+ base_prompt = review_instruction_form
144
+
145
+ base_prompt += f"""
146
+ Here is the paper you are asked to review:
147
+ ```
148
+ {text}
149
+ ```"""
150
+
151
+ if num_reviews_ensemble > 1:
152
+ llm_review, msg_histories = get_batch_responses_from_llm(
153
+ base_prompt,
154
+ model=model,
155
+ client=client,
156
+ system_message=reviewer_system_prompt,
157
+ print_debug=False,
158
+ msg_history=msg_history,
159
+ # Higher temperature to encourage diversity.
160
+ temperature=0.75,
161
+ n_responses=num_reviews_ensemble,
162
+ )
163
+ parsed_reviews = []
164
+ for idx, rev in enumerate(llm_review):
165
+ try:
166
+ parsed_reviews.append(extract_json_between_markers(rev))
167
+ except Exception as e:
168
+ print(f"Ensemble review {idx} failed: {e}")
169
+ parsed_reviews = [r for r in parsed_reviews if r is not None]
170
+ review = get_meta_review(model, client, temperature, parsed_reviews)
171
+
172
+ # take first valid in case meta-reviewer fails
173
+ if review is None:
174
+ review = parsed_reviews[0]
175
+
176
+ # Replace numerical scores with the average of the ensemble.
177
+ for score, limits in [
178
+ ("Originality", (1, 4)),
179
+ ("Quality", (1, 4)),
180
+ ("Clarity", (1, 4)),
181
+ ("Significance", (1, 4)),
182
+ ("Soundness", (1, 4)),
183
+ ("Presentation", (1, 4)),
184
+ ("Contribution", (1, 4)),
185
+ ("Overall", (1, 10)),
186
+ ("Confidence", (1, 5)),
187
+ ]:
188
+ scores = []
189
+ for r in parsed_reviews:
190
+ if score in r and limits[1] >= r[score] >= limits[0]:
191
+ scores.append(r[score])
192
+ review[score] = int(round(np.mean(scores)))
193
+
194
+ # Rewrite the message history with the valid one and new aggregated review.
195
+ msg_history = msg_histories[0][:-1]
196
+ msg_history += [
197
+ {
198
+ "role": "assistant",
199
+ "content": f"""
200
+ THOUGHT:
201
+ I will start by aggregating the opinions of {num_reviews_ensemble} reviewers that I previously obtained.
202
+
203
+ REVIEW JSON:
204
+ ```json
205
+ {json.dumps(review)}
206
+ ```
207
+ """,
208
+ }
209
+ ]
210
+ else:
211
+ llm_review, msg_history = get_response_from_llm(
212
+ base_prompt,
213
+ model=model,
214
+ client=client,
215
+ system_message=reviewer_system_prompt,
216
+ print_debug=False,
217
+ msg_history=msg_history,
218
+ temperature=temperature,
219
+ )
220
+ review = extract_json_between_markers(llm_review)
221
+
222
+ if num_reflections > 1:
223
+ for j in range(num_reflections - 1):
224
+ # print(f"Reflection: {j + 2}/{num_reflections}")
225
+ text, msg_history = get_response_from_llm(
226
+ reviewer_reflection_prompt,
227
+ client=client,
228
+ model=model,
229
+ system_message=reviewer_system_prompt,
230
+ msg_history=msg_history,
231
+ temperature=temperature,
232
+ )
233
+ review = extract_json_between_markers(text)
234
+ assert review is not None, "Failed to extract JSON from LLM output"
235
+
236
+ if "I am done" in text:
237
+ # print(f"Review generation converged after {j + 2} iterations.")
238
+ break
239
+
240
+ if return_msg_history:
241
+ return review, msg_history
242
+ else:
243
+ return review
244
+
245
+
246
+ reviewer_reflection_prompt = """Round {current_round}/{num_reflections}.
247
+ In your thoughts, first carefully consider the accuracy and soundness of the review you just created.
248
+ Include any other factors that you think are important in evaluating the paper.
249
+ Ensure the review is clear and concise, and the JSON is in the correct format.
250
+ Do not make things overly complicated.
251
+ In the next attempt, try to refine and improve your review.
252
+ Stick to the spirit of the original review unless there are glaring issues.
253
+
254
+ Respond in the same format as before:
255
+ THOUGHT:
256
+ <THOUGHT>
257
+
258
+ REVIEW JSON:
259
+ ```json
260
+ <JSON>
261
+ ```
262
+
263
+ If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
264
+ ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES."""
265
+
266
+
267
+ def load_paper(pdf_path, num_pages=None, min_size=100):
268
+ try:
269
+ if num_pages is None:
270
+ text = pymupdf4llm.to_markdown(pdf_path)
271
+ else:
272
+ reader = PdfReader(pdf_path)
273
+ min_pages = min(len(reader.pages), num_pages)
274
+ text = pymupdf4llm.to_markdown(pdf_path, pages=list(range(min_pages)))
275
+ if len(text) < min_size:
276
+ raise Exception("Text too short")
277
+ except Exception as e:
278
+ print(f"Error with pymupdf4llm, falling back to pymupdf: {e}")
279
+ try:
280
+ doc = pymupdf.open(pdf_path) # open a document
281
+ if num_pages:
282
+ doc = doc[:num_pages]
283
+ text = ""
284
+ for page in doc: # iterate the document pages
285
+ text = text + page.get_text() # get plain text encoded as UTF-8
286
+ if len(text) < min_size:
287
+ raise Exception("Text too short")
288
+ except Exception as e:
289
+ print(f"Error with pymupdf, falling back to pypdf: {e}")
290
+ reader = PdfReader(pdf_path)
291
+ if num_pages is None:
292
+ text = "".join(page.extract_text() for page in reader.pages)
293
+ else:
294
+ text = "".join(page.extract_text() for page in reader.pages[:num_pages])
295
+ if len(text) < min_size:
296
+ raise Exception("Text too short")
297
+
298
+ return text
299
+
300
+
301
+ def load_review(path):
302
+ with open(path, "r") as json_file:
303
+ loaded = json.load(json_file)
304
+ return loaded["review"]
305
+
306
+
307
+ # get directory of this file
308
+ dir_path = os.path.dirname(os.path.realpath(__file__))
309
+
310
+ fewshot_papers = [
311
+ os.path.join(dir_path, "fewshot_examples/132_automated_relational.pdf"),
312
+ os.path.join(dir_path, "fewshot_examples/attention.pdf"),
313
+ os.path.join(dir_path, "fewshot_examples/2_carpe_diem.pdf"),
314
+ ]
315
+
316
+ fewshot_reviews = [
317
+ os.path.join(dir_path, "fewshot_examples/132_automated_relational.json"),
318
+ os.path.join(dir_path, "fewshot_examples/attention.json"),
319
+ os.path.join(dir_path, "fewshot_examples/2_carpe_diem.json"),
320
+ ]
321
+
322
+
323
+ def get_review_fewshot_examples(num_fs_examples=1):
324
+ fewshot_prompt = """
325
+ Below are some sample reviews, copied from previous machine learning conferences.
326
+ Note that while each review is formatted differently according to each reviewer's style, the reviews are well-structured and therefore easy to navigate.
327
+ """
328
+ for paper, review in zip(
329
+ fewshot_papers[:num_fs_examples], fewshot_reviews[:num_fs_examples]
330
+ ):
331
+ txt_path = paper.replace(".pdf", ".txt")
332
+ if os.path.exists(txt_path):
333
+ with open(txt_path, "r") as f:
334
+ paper_text = f.read()
335
+ else:
336
+ paper_text = load_paper(paper)
337
+ review_text = load_review(review)
338
+ fewshot_prompt += f"""
339
+ Paper:
340
+
341
+ ```
342
+ {paper_text}
343
+ ```
344
+
345
+ Review:
346
+
347
+ ```
348
+ {review_text}
349
+ ```
350
+ """
351
+
352
+ return fewshot_prompt
353
+
354
+
355
+ meta_reviewer_system_prompt = """You are an Area Chair at a machine learning conference.
356
+ You are in charge of meta-reviewing a paper that was reviewed by {reviewer_count} reviewers.
357
+ Your job is to aggregate the reviews into a single meta-review in the same format.
358
+ Be critical and cautious in your decision, find consensus, and respect the opinion of all the reviewers."""
359
+
360
+
361
+ def get_meta_review(model, client, temperature, reviews):
362
+ # Write a meta-review from a set of individual reviews
363
+ review_text = ""
364
+ for i, r in enumerate(reviews):
365
+ review_text += f"""
366
+ Review {i + 1}/{len(reviews)}:
367
+ ```
368
+ {json.dumps(r)}
369
+ ```
370
+ """
371
+ base_prompt = neurips_form + review_text
372
+
373
+ llm_review, msg_history = get_response_from_llm(
374
+ base_prompt,
375
+ model=model,
376
+ client=client,
377
+ system_message=meta_reviewer_system_prompt.format(reviewer_count=len(reviews)),
378
+ print_debug=False,
379
+ msg_history=None,
380
+ temperature=temperature,
381
+ )
382
+ meta_review = extract_json_between_markers(llm_review)
383
+ return meta_review
384
+
385
+
386
+ def perform_improvement(review, coder):
387
+ improvement_prompt = '''The following review has been created for your research paper:
388
+ """
389
+ {review}
390
+ """
391
+
392
+ Improve the text using the review.'''.format(
393
+ review=json.dumps(review)
394
+ )
395
+ coder_out = coder.run(improvement_prompt)
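
A rough sketch of reviewing a single PDF with this module; the PDF path and model name are placeholders, an API key is assumed to be configured, and the printed fields are only present if the model returns the requested JSON:

```python
# Illustrative only: "paper_to_review.pdf" and the model name are placeholders.
from ai_scientist.llm import create_client
from ai_scientist.perform_review import load_paper, perform_review

client, client_model = create_client("gpt-4o-2024-05-13")
paper_text = load_paper("paper_to_review.pdf")  # hypothetical path to a local PDF

review = perform_review(
    paper_text,
    model=client_model,
    client=client,
    num_reflections=1,
    num_fs_examples=1,
    num_reviews_ensemble=1,
)
# Assumes the model followed the requested review JSON format.
print(review["Overall"], review["Decision"])
```
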
ai_scientist/perform_writeup.py ADDED
@@ -0,0 +1,579 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import os.path as osp
5
+ import re
6
+ import shutil
7
+ import subprocess
8
+ from typing import Optional, Tuple
9
+
10
+ from ai_scientist.generate_ideas import search_for_papers
11
+ from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
12
+
13
+
14
+ # GENERATE LATEX
15
+ def generate_latex(coder, folder_name, pdf_file, timeout=30, num_error_corrections=5):
16
+ folder = osp.abspath(folder_name)
17
+ cwd = osp.join(folder, "latex") # Fixed potential issue with path
18
+ writeup_file = osp.join(cwd, "template.tex")
19
+
20
+ # Check all references are valid and in the references.bib file
21
+ with open(writeup_file, "r") as f:
22
+ tex_text = f.read()
23
+ cites = re.findall(r"\\cite[a-z]*{([^}]*)}", tex_text)
24
+ references_bib = re.search(
25
+ r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
26
+ tex_text,
27
+ re.DOTALL,
28
+ )
29
+ if references_bib is None:
30
+ print("No references.bib found in template.tex")
31
+ return
32
+ bib_text = references_bib.group(1)
33
+ cites = [cite.strip() for item in cites for cite in item.split(",")]
34
+ for cite in cites:
35
+ if cite not in bib_text:
36
+ print(f"Reference {cite} not found in references.")
37
+ prompt = f"""Reference {cite} not found in references.bib. Is this included under a different name?
38
+ If so, please modify the citation in template.tex to match the name in references.bib at the top. Otherwise, remove the cite."""
39
+ coder.run(prompt)
40
+
41
+ # Check all included figures are actually in the directory.
42
+ with open(writeup_file, "r") as f:
43
+ tex_text = f.read()
44
+ referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
45
+ all_figs = [f for f in os.listdir(folder) if f.endswith(".png")]
46
+ for figure in referenced_figs:
47
+ if figure not in all_figs:
48
+ print(f"Figure {figure} not found in directory.")
49
+ prompt = f"""The image {figure} was not found in the directory. The images in the directory are: {all_figs}.
50
+ Please ensure that the figure is in the directory and that the filename is correct. Check the notes to see what each figure contains."""
51
+ coder.run(prompt)
52
+
53
+ # Remove duplicate figures.
54
+ with open(writeup_file, "r") as f:
55
+ tex_text = f.read()
56
+ referenced_figs = re.findall(r"\\includegraphics.*?{(.*?)}", tex_text)
57
+ duplicates = {x for x in referenced_figs if referenced_figs.count(x) > 1}
58
+ if duplicates:
59
+ for dup in duplicates:
60
+ print(f"Duplicate figure found: {dup}.")
61
+ prompt = f"""Duplicate figures found: {dup}. Ensure any figure is only included once.
62
+ If duplicated, identify the best location for the figure and remove any other."""
63
+ coder.run(prompt)
64
+
65
+ # Remove duplicate section headers.
66
+ with open(writeup_file, "r") as f:
67
+ tex_text = f.read()
68
+ sections = re.findall(r"\\section{([^}]*)}", tex_text)
69
+ duplicates = {x for x in sections if sections.count(x) > 1}
70
+ if duplicates:
71
+ for dup in duplicates:
72
+ print(f"Duplicate section header found: {dup}")
73
+ prompt = f"""Duplicate section header found: {dup}. Ensure any section header is declared once.
74
+ If duplicated, identify the best location for the section header and remove any other."""
75
+ coder.run(prompt)
76
+
77
+ # Iteratively fix any LaTeX bugs
78
+ for i in range(num_error_corrections):
79
+ # Filter trivial bugs in chktex
80
+ check_output = os.popen(f"chktex {writeup_file} -q -n2 -n24 -n13 -n1").read()
81
+ if check_output:
82
+ prompt = f"""Please fix the following LaTeX errors in `template.tex` guided by the output of `chktex`:
83
+ {check_output}.
84
+
85
+ Make the minimal fix required and do not remove or change any packages.
86
+ Pay attention to any accidental uses of HTML syntax, e.g. </end instead of \\end.
87
+ """
88
+ coder.run(prompt)
89
+ else:
90
+ break
91
+ compile_latex(cwd, pdf_file, timeout=timeout)
92
+
93
+
94
+ def compile_latex(cwd, pdf_file, timeout=30):
95
+ print("GENERATING LATEX")
96
+
97
+ commands = [
98
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
99
+ ["bibtex", "template"],
100
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
101
+ ["pdflatex", "-interaction=nonstopmode", "template.tex"],
102
+ ]
103
+
104
+ for command in commands:
105
+ try:
106
+ result = subprocess.run(
107
+ command,
108
+ cwd=cwd,
109
+ stdout=subprocess.PIPE,
110
+ stderr=subprocess.PIPE,
111
+ text=True,
112
+ timeout=timeout,
113
+ )
114
+ print("Standard Output:\n", result.stdout)
115
+ print("Standard Error:\n", result.stderr)
116
+ except subprocess.TimeoutExpired:
117
+ print(f"Latex timed out after {timeout} seconds")
118
+ except subprocess.CalledProcessError as e:
119
+ print(f"Error running command {' '.join(command)}: {e}")
120
+
121
+ print("FINISHED GENERATING LATEX")
122
+
123
+ # Attempt to move the PDF to the desired location
124
+ try:
125
+ shutil.move(osp.join(cwd, "template.pdf"), pdf_file)
126
+ except FileNotFoundError:
127
+ print("Failed to rename PDF.")
128
+
129
+
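As a usage note: compile_latex only needs the path to the latex/ directory and the desired output PDF path. A minimal sketch of calling it on its own might look like the lines below; the results/my_idea folder name is a hypothetical placeholder, not taken from this repository.

import os.path as osp

folder = osp.abspath("results/my_idea")  # assumed project folder containing latex/template.tex
compile_latex(osp.join(folder, "latex"), osp.join(folder, "my_idea.pdf"), timeout=60)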
130
+ per_section_tips = {
131
+ "Abstract": """
132
+ - TL;DR of the paper
133
+ - What are we trying to do and why is it relevant?
134
+ - Why is this hard?
135
+ - How do we solve it (i.e. our contribution!)
136
+ - How do we verify that we solved it (e.g. Experiments and results)
137
+
138
+ Please make sure the abstract reads smoothly and is well-motivated. This should be one continuous paragraph with no breaks between the lines.
139
+ """,
140
+ "Introduction": """
141
+ - Longer version of the Abstract, i.e. of the entire paper
142
+ - What are we trying to do and why is it relevant?
143
+ - Why is this hard?
144
+ - How do we solve it (i.e. our contribution!)
145
+ - How do we verify that we solved it (e.g. Experiments and results)
146
+ - New trend: specifically list your contributions as bullet points
147
+ - Extra space? Future work!
148
+ """,
149
+ "Related Work": """
150
+ - Academic siblings of our work, i.e. alternative attempts in literature at trying to solve the same problem.
151
+ - Goal is to “Compare and contrast” - how does their approach differ in either assumptions or method? If their method is applicable to our Problem Setting I expect a comparison in the experimental section. If not, there needs to be a clear statement why a given method is not applicable.
152
+ - Note: Just describing what another paper is doing is not enough. We need to compare and contrast.
153
+ """,
154
+ "Background": """
155
+ - Academic Ancestors of our work, i.e. all concepts and prior work that are required for understanding our method.
156
+ - Usually includes a subsection, Problem Setting, which formally introduces the problem setting and notation (Formalism) for our method. Highlights any specific assumptions that are made that are unusual.
157
+ - Note: If our paper introduces a novel problem setting as part of its contributions, it's best to have a separate Section.
158
+ """,
159
+ "Method": """
160
+ - What we do. Why we do it. All described using the general Formalism introduced in the Problem Setting and building on top of the concepts / foundations introduced in Background.
161
+ """,
162
+ "Experimental Setup": """
163
+ - How do we test that our stuff works? Introduces a specific instantiation of the Problem Setting and specific implementation details of our Method for this Problem Setting.
164
+ - Do not imagine unknown hardware details.
165
+ - Includes a description of the dataset, evaluation metrics, important hyperparameters, and implementation details.
166
+ """,
167
+ "Results": """
168
+ - Shows the results of running Method on our problem described in Experimental Setup.
169
+ - Includes statements on hyperparameters and other potential issues of fairness.
170
+ - Only includes results that have actually been run and saved in the logs. Do not hallucinate results that don't exist.
171
+ - If results exist: compares to baselines and includes statistics and confidence intervals.
172
+ - If results exist: includes ablation studies to show that specific parts of the method are relevant.
173
+ - Discusses limitations of the method.
174
+ - Make sure to include all the results from the experiments, and include all relevant figures.
175
+ """,
176
+ "Conclusion": """
177
+ - Brief recap of the entire paper.
178
+ - To keep going with the analogy, you can think of future work as (potential) academic offspring.
179
+ """,
180
+ }
181
+
182
+ error_list = """- Unenclosed math symbols
183
+ - Only reference figures that exist in our directory
184
+ - LaTeX syntax errors
185
+ - Numerical results that do not come from explicit experiments and logs
186
+ - Repeatedly defined figure labels
187
+ - References to papers that are not in the .bib file, DO NOT ADD ANY NEW CITATIONS!
188
+ - Unnecessary verbosity or repetition, unclear text
189
+ - Results or insights in the `notes.txt` that have not yet been included
190
+ - Any relevant figures that have not yet been included in the text
191
+ - Closing any \\begin{{figure}} with a \\end{{figure}} and \\begin{{table}} with a \\end{{table}}, etc.
192
+ - Duplicate headers, e.g. duplicated \\section{{Introduction}} or \\end{{document}}
193
+ - Unescaped symbols, e.g. shakespeare_char should be shakespeare\\_char in text
194
+ - Incorrect closing of environments, e.g. </end{{figure}}> instead of \\end{{figure}}
195
+ """
196
+
197
+ refinement_prompt = (
198
+ """Great job! Now criticize and refine only the {section} that you just wrote.
199
+ Make this complete in this pass, do not leave any placeholders.
200
+
201
+ Pay particular attention to fixing any errors such as:
202
+ """
203
+ + error_list
204
+ )
205
+
206
+ second_refinement_prompt = (
207
+ """Criticize and refine the {section} only. Recall the advice:
208
+ {tips}
209
+ Make this complete in this pass, do not leave any placeholders.
210
+
211
+ Pay attention to how it fits in with the rest of the paper.
212
+ Identify any redundancies (e.g. repeated figures or repeated text); if there are any, decide where in the paper things should be cut.
213
+ Identify where we can save space, and be more concise without weakening the message of the text.
214
+ Fix any remaining errors as before:
215
+ """
216
+ + error_list
217
+ )
218
+
219
+ # CITATION HELPERS
220
+ citation_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
221
+ You have already written an initial draft of the paper and now you are looking to add missing citations to related papers throughout the paper.
222
+ The related work section already has some initial comments on which papers to add and discuss.
223
+
224
+ Focus on completing the existing write-up and do not add entirely new elements unless necessary.
225
+ Ensure every point in the paper is substantiated with sufficient evidence.
226
+ Feel free to add more cites to a particular point if there are only one or two references.
227
+ Ensure no paper is cited without a corresponding reference in the `references.bib` file.
228
+ Ensure each paragraph of the related work has sufficient background, e.g. a few papers cited.
229
+ You will be given access to the Semantic Scholar API, only add citations that you have found using the API.
230
+ Aim to discuss a broad range of relevant papers, not just the most popular ones.
231
+ Make sure not to copy verbatim from prior literature to avoid plagiarism.
232
+
233
+ You will be prompted to give a precise description of where and how to add the cite, and a search query for the paper to be cited.
234
+ Finally, you will select the most relevant cite from the search results (top 10 results will be shown).
235
+ You will have {total_rounds} rounds to add to the references, but do not need to use them all.
236
+
237
+ DO NOT ADD A CITATION THAT ALREADY EXISTS!"""
238
+
239
+ citation_first_prompt = '''Round {current_round}/{total_rounds}:
240
+
241
+ You have written this LaTeX draft so far:
242
+
243
+ """
244
+ {draft}
245
+ """
246
+
247
+ Identify the most important citation that you still need to add, and the query to find the paper.
248
+
249
+ Respond in the following format:
250
+
251
+ THOUGHT:
252
+ <THOUGHT>
253
+
254
+ RESPONSE:
255
+ ```json
256
+ <JSON>
257
+ ```
258
+
259
+ In <THOUGHT>, first briefly reason over the paper and identify where citations should be added.
260
+ If no more citations are needed, add "No more citations needed" to your thoughts.
261
+ Do not add "No more citations needed" if you are adding citations this round.
262
+
263
+ In <JSON>, respond in JSON format with the following fields:
264
+ - "Description": A precise description of the required edit, along with the proposed text and location where it should be made.
265
+ - "Query": The search query to find the paper (e.g. attention is all you need).
266
+
267
+ Ensure the description is sufficient to make the change without further context. Someone else will make the change.
268
+ The query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
269
+ This JSON will be automatically parsed, so ensure the format is precise.'''
270
+
271
+ citation_second_prompt = """Search has recovered the following articles:
272
+
273
+ {papers}
274
+
275
+ Respond in the following format:
276
+
277
+ THOUGHT:
278
+ <THOUGHT>
279
+
280
+ RESPONSE:
281
+ ```json
282
+ <JSON>
283
+ ```
284
+
285
+ In <THOUGHT>, first briefly reason over the search results and identify which citation best fits your paper and the location where it should be added.
286
+ If none are appropriate, add "Do not add any" to your thoughts.
287
+
288
+ In <JSON>, respond in JSON format with the following fields:
289
+ - "Selected": A list of the indices of the selected papers to be cited, e.g. "[0, 1]". Can be "[]" if no papers are selected. This must be a string.
290
+ - "Description": Update the previous description of the required edit if needed. Ensure that any cites precisely match the name in the bibtex!!!
291
+
292
+ Do not select papers that are already in the `references.bib` file at the top of the draft, or if the same citation exists under a different name.
293
+ This JSON will be automatically parsed, so ensure the format is precise."""
294
+
295
+
296
+ def get_citation_aider_prompt(
297
+ client, model, draft, current_round, total_rounds, engine="semanticscholar"
298
+ ) -> Tuple[Optional[str], bool]:
299
+ msg_history = []
300
+ try:
301
+ text, msg_history = get_response_from_llm(
302
+ citation_first_prompt.format(
303
+ draft=draft, current_round=current_round, total_rounds=total_rounds
304
+ ),
305
+ client=client,
306
+ model=model,
307
+ system_message=citation_system_msg.format(total_rounds=total_rounds),
308
+ msg_history=msg_history,
309
+ )
310
+ if "No more citations needed" in text:
311
+ print("No more citations needed.")
312
+ return None, True
313
+
314
+ ## PARSE OUTPUT
315
+ json_output = extract_json_between_markers(text)
316
+ assert json_output is not None, "Failed to extract JSON from LLM output"
317
+ query = json_output["Query"]
318
+ papers = search_for_papers(query, engine=engine)
319
+ except Exception as e:
320
+ print(f"Error: {e}")
321
+ return None, False
322
+
323
+ if papers is None:
324
+ print("No papers found.")
325
+ return None, False
326
+
327
+ paper_strings = []
328
+ for i, paper in enumerate(papers):
329
+ paper_strings.append(
330
+ """{i}: {title}. {authors}. {venue}, {year}.\nAbstract: {abstract}""".format(
331
+ i=i,
332
+ title=paper["title"],
333
+ authors=paper["authors"],
334
+ venue=paper["venue"],
335
+ year=paper["year"],
336
+ abstract=paper["abstract"],
337
+ )
338
+ )
339
+ papers_str = "\n\n".join(paper_strings)
340
+
341
+ try:
342
+ text, msg_history = get_response_from_llm(
343
+ citation_second_prompt.format(
344
+ papers=papers_str,
345
+ current_round=current_round,
346
+ total_rounds=total_rounds,
347
+ ),
348
+ client=client,
349
+ model=model,
350
+ system_message=citation_system_msg.format(total_rounds=total_rounds),
351
+ msg_history=msg_history,
352
+ )
353
+ if "Do not add any" in text:
354
+ print("Do not add any.")
355
+ return None, False
356
+ ## PARSE OUTPUT
357
+ json_output = extract_json_between_markers(text)
358
+ assert json_output is not None, "Failed to extract JSON from LLM output"
359
+ desc = json_output["Description"]
360
+ selected_papers = json_output["Selected"]
361
+ selected_papers = str(selected_papers)
362
+
363
+ # convert to list
364
+ if selected_papers != "[]":
365
+ selected_papers = list(map(int, selected_papers.strip("[]").split(",")))
366
+ assert all(
367
+ [0 <= i < len(papers) for i in selected_papers]
368
+ ), "Invalid paper index"
369
+ bibtexs = [papers[i]["citationStyles"]["bibtex"] for i in selected_papers]
370
+ bibtex_string = "\n".join(bibtexs)
371
+ else:
372
+ return None, False
373
+
374
+ except Exception as e:
375
+ print(f"Error: {e}")
376
+ return None, False
377
+
378
+ # Add citation to draft
379
+ aider_format = '''The following citations have just been added to the end of the `references.bib` file definition at the top of the file:
380
+ """
381
+ {bibtex}
382
+ """
383
+ You do not need to add them yourself.
384
+ ABSOLUTELY DO NOT ADD IT AGAIN!!!
385
+
386
+ Make the proposed change to the draft incorporating these new cites:
387
+ {description}
388
+
389
+ Use your judgment for whether these should be cited anywhere else.
390
+ Make sure that any citation precisely matches the name in `references.bib`. Change its name to the correct name in the bibtex if needed.
391
+ Ensure the citation is well-integrated into the text.'''
392
+
393
+ aider_prompt = (
394
+ aider_format.format(bibtex=bibtex_string, description=desc)
395
+ + """\n You must use \cite or \citet to reference papers, do not manually type out author names."""
396
+ )
397
+ return aider_prompt, False
398
+
399
+
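A minimal sketch of driving this citation helper outside of perform_writeup; the model name, draft path, and coder object here are assumptions for illustration only.

client, client_model = create_client("gpt-4o-2024-05-13")
with open("latex/template.tex") as f:
    draft = f.read()
aider_prompt, done = get_citation_aider_prompt(client, client_model, draft, current_round=0, total_rounds=20)
if aider_prompt is not None and not done:
    # `coder` is an aider Coder with template.tex in its file list, configured as in the __main__ block below
    coder.run(aider_prompt)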
400
+ # PERFORM WRITEUP
401
+ def perform_writeup(
402
+ idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20, engine="semanticscholar"
403
+ ):
404
+ # CURRENTLY ASSUMES LATEX
405
+ abstract_prompt = f"""We've provided the `latex/template.tex` file to the project. We will be filling it in section by section.
406
+
407
+ First, please fill in the "Title" and "Abstract" sections of the writeup.
408
+
409
+ Some tips are provided below:
410
+ {per_section_tips["Abstract"]}
411
+
412
+ Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
413
+
414
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
415
+ """
416
+ coder_out = coder.run(abstract_prompt)
417
+ coder_out = coder.run(
418
+ refinement_prompt.format(section="Abstract")
419
+ .replace(r"{{", "{")
420
+ .replace(r"}}", "}")
421
+ )
422
+ for section in [
423
+ "Introduction",
424
+ "Background",
425
+ "Method",
426
+ "Experimental Setup",
427
+ "Results",
428
+ "Conclusion",
429
+ ]:
430
+ section_prompt = f"""Please fill in the {section} of the writeup. Some tips are provided below:
431
+ {per_section_tips[section]}
432
+
433
+ Be sure to use \cite or \citet where relevant, referring to the works provided in the file.
434
+ Do not cite anything that is not already in `references.bib`. Do not add any new entries to this.
435
+
436
+ Keep the experimental results (figures and tables) only in the Results section, and make sure that any captions are filled in.
437
+ In this pass, do not reference anything in later sections of the paper.
438
+
439
+ Before every paragraph, please include a brief description of what you plan to write in that paragraph in a comment.
440
+
441
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
442
+ """
443
+ coder_out = coder.run(section_prompt)
444
+ coder_out = coder.run(
445
+ refinement_prompt.format(section=section)
446
+ .replace(r"{{", "{")
447
+ .replace(r"}}", "}")
448
+ )
449
+
450
+ # SKETCH THE RELATED WORK
451
+ section_prompt = f"""Please fill in the Related Work of the writeup. Some tips are provided below:
452
+
453
+ {per_section_tips["Related Work"]}
454
+
455
+ For this section, very briefly sketch out the structure of the section, and clearly indicate what papers you intend to include.
456
+ Do this all in LaTeX comments using %.
457
+ The related work should be concise, only plan to discuss the most relevant work.
458
+ Do not modify `references.bib` to add any new citations, this will be filled in at a later stage.
459
+
460
+ Be sure to first name the file and use *SEARCH/REPLACE* blocks to perform these edits.
461
+ """
462
+ coder_out = coder.run(section_prompt)
463
+
464
+ # Fill paper with cites.
465
+ for _ in range(num_cite_rounds):
466
+ with open(osp.join(folder_name, "latex", "template.tex"), "r") as f:
467
+ draft = f.read()
468
+ prompt, done = get_citation_aider_prompt(
469
+ cite_client, cite_model, draft, _, num_cite_rounds, engine=engine
470
+ )
471
+ if done:
472
+ break
473
+ if prompt is not None:
474
+ # extract bibtex string
475
+ bibtex_string = prompt.split('"""')[1]
476
+ # insert this into draft before the "\end{filecontents}" line
477
+ search_str = r"\end{filecontents}"
478
+ draft = draft.replace(search_str, f"{bibtex_string}{search_str}")
479
+ with open(osp.join(folder_name, "latex", "template.tex"), "w") as f:
480
+ f.write(draft)
481
+ coder_out = coder.run(prompt)
482
+
483
+ coder_out = coder.run(
484
+ refinement_prompt.format(section="Related Work")
485
+ .replace(r"{{", "{")
486
+ .replace(r"}}", "}")
487
+ )
488
+
489
+ ## SECOND REFINEMENT LOOP
490
+ coder.run(
491
+ """Great job! Now that there is a complete draft of the entire paper, let's refine each section again.
492
+ First, re-think the Title if necessary. Keep this concise and descriptive of the paper's concept, but try to be creative with it."""
493
+ )
494
+ for section in [
495
+ "Abstract",
496
+ "Related Work",
497
+ "Introduction",
498
+ "Background",
499
+ "Method",
500
+ "Experimental Setup",
501
+ "Results",
502
+ "Conclusion",
503
+ ]:
504
+ coder_out = coder.run(
505
+ second_refinement_prompt.format(
506
+ section=section, tips=per_section_tips[section]
507
+ )
508
+ .replace(r"{{", "{")
509
+ .replace(r"}}", "}")
510
+ )
511
+
512
+ generate_latex(coder, folder_name, f"{folder_name}/{idea['Name']}.pdf")
513
+
514
+
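For orientation, perform_writeup assumes a project folder that already contains latex/template.tex (with its filecontents references.bib block) and the experiment files tracked by the coder. A rough, illustrative call might look as follows; the idea dict and folder path are placeholders.

idea = {"Name": "my_idea"}  # only the Name field is read here, to name the output PDF
perform_writeup(idea, "results/my_idea", coder, cite_client, cite_model, num_cite_rounds=20)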
515
+ if __name__ == "__main__":
516
+ from aider.coders import Coder
517
+ from aider.models import Model
518
+ from aider.io import InputOutput
519
+ import json
520
+
521
+ parser = argparse.ArgumentParser(description="Perform writeup for a project")
522
+ parser.add_argument("--folder", type=str)
523
+ parser.add_argument("--no-writing", action="store_true", help="Only generate")
524
+ parser.add_argument(
525
+ "--model",
526
+ type=str,
527
+ default="gpt-4o-2024-05-13",
528
+ choices=AVAILABLE_LLMS,
529
+ help="Model to use for AI Scientist.",
530
+ )
531
+ parser.add_argument(
532
+ "--engine",
533
+ type=str,
534
+ default="semanticscholar",
535
+ choices=["semanticscholar", "openalex"],
536
+ help="Scholar engine to use.",
537
+ )
538
+ args = parser.parse_args()
539
+ client, client_model = create_client(args.model)
540
+ print("Make sure you cleaned the Aider logs if re-generating the writeup!")
541
+ folder_name = args.folder
542
+ idea_name = osp.basename(folder_name)
543
+ exp_file = osp.join(folder_name, "experiment.py")
544
+ vis_file = osp.join(folder_name, "plot.py")
545
+ notes = osp.join(folder_name, "notes.txt")
546
+ model = args.model
547
+ writeup_file = osp.join(folder_name, "latex", "template.tex")
548
+ ideas_file = osp.join(folder_name, "ideas.json")
549
+ with open(ideas_file, "r") as f:
550
+ ideas = json.load(f)
551
+ for idea in ideas:
552
+ if idea["Name"] in idea_name:
553
+ print(f"Found idea: {idea['Name']}")
554
+ break
555
+ if idea["Name"] not in idea_name:
556
+ raise ValueError(f"Idea {idea_name} not found")
557
+ fnames = [exp_file, writeup_file, notes]
558
+ io = InputOutput(yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt")
559
+ if args.model == "deepseek-coder-v2-0724":
560
+ main_model = Model("deepseek/deepseek-coder")
561
+ elif args.model == "llama3.1-405b":
562
+ main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
563
+ else:
564
+ main_model = Model(model)
565
+ coder = Coder.create(
566
+ main_model=main_model,
567
+ fnames=fnames,
568
+ io=io,
569
+ stream=False,
570
+ use_git=False,
571
+ edit_format="diff",
572
+ )
573
+ if args.no_writing:
574
+ generate_latex(coder, args.folder, f"{args.folder}/test.pdf")
575
+ else:
576
+ try:
577
+ perform_writeup(idea, folder_name, coder, client, client_model, engine=args.engine)
578
+ except Exception as e:
579
+ print(f"Failed to perform writeup: {e}")