SNUMPR committed on
Commit b56394a
1 Parent(s): 6604444

Upload 40 files

Files changed (41)
  1. .gitattributes +2 -0
  2. models/instructions_processed_LP/ALFRED_task_helper.py +296 -0
  3. models/instructions_processed_LP/BERT/best_models/base.pt +3 -0
  4. models/instructions_processed_LP/BERT/best_models/mrecep.pt +3 -0
  5. models/instructions_processed_LP/BERT/best_models/object.pt +3 -0
  6. models/instructions_processed_LP/BERT/best_models/parent.pt +3 -0
  7. models/instructions_processed_LP/BERT/best_models/sliced.pt +3 -0
  8. models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/correct_labels_dict_ppdl.p +0 -0
  9. models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/correct_template_by_label_ppdl.p +0 -0
  10. models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx.p +0 -0
  11. models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx_new_split.p +0 -0
  12. models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx.p +0 -0
  13. models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx_new_split.p +0 -0
  14. models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/template_by_label.p +0 -0
  15. models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/toggle2idx.p +0 -0
  16. models/instructions_processed_LP/BERT/data/alfred_data/create_text_with_pddl_low_appended.py +112 -0
  17. models/instructions_processed_LP/BERT/data/alfred_data/create_text_with_pddl_low_appended_for_new_split.py +191 -0
  18. models/instructions_processed_LP/BERT/data/alfred_data/tests_seen_task_desc_to_task_id_oct24.p +0 -0
  19. models/instructions_processed_LP/BERT/data/alfred_data/tests_seen_text_with_ppdl_low_appended_new_split_oct24.p +0 -0
  20. models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_task_desc_to_task_id_oct24.p +0 -0
  21. models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_GT.p +0 -0
  22. models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_new_split_oct24.p +0 -0
  23. models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_test_unseen_new_split_GT.p +0 -0
  24. models/instructions_processed_LP/BERT/data/alfred_data/train_task_desc_to_task_id_oct24.p +3 -0
  25. models/instructions_processed_LP/BERT/data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p +3 -0
  26. models/instructions_processed_LP/BERT/data/alfred_data/valid_seen_task_desc_to_task_id_oct24.p +0 -0
  27. models/instructions_processed_LP/BERT/data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split_oct24.p +0 -0
  28. models/instructions_processed_LP/BERT/data/alfred_data/valid_unseen_task_desc_to_task_id_oct24.p +0 -0
  29. models/instructions_processed_LP/BERT/data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split_oct24.p +0 -0
  30. models/instructions_processed_LP/BERT/end_to_end_outputs.py +210 -0
  31. models/instructions_processed_LP/BERT/train_bert_args.py +301 -0
  32. models/instructions_processed_LP/BERT/train_bert_base.py +179 -0
  33. models/instructions_processed_LP/compare_BERT_pred_with_GT_oct24.py +83 -0
  34. models/instructions_processed_LP/instruction2_params_tests_seen_appended_new_split_oct24.p +0 -0
  35. models/instructions_processed_LP/instruction2_params_tests_seen_new_split_GT_oct24.p +0 -0
  36. models/instructions_processed_LP/instruction2_params_tests_unseen_appended_new_split_oct24.p +0 -0
  37. models/instructions_processed_LP/instruction2_params_tests_unseen_new_split_GT_oct24.p +0 -0
  38. models/instructions_processed_LP/instruction2_params_valid_seen_appended_new_split_oct24.p +0 -0
  39. models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_oct24.p +0 -0
  40. models/instructions_processed_LP/instruction2_params_valid_unseen_appended_new_split_oct24.p +0 -0
  41. models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_oct24.p +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ models/instructions_processed_LP/BERT/data/alfred_data/train_task_desc_to_task_id_oct24.p filter=lfs diff=lfs merge=lfs -text
+ models/instructions_processed_LP/BERT/data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p filter=lfs diff=lfs merge=lfs -text
models/instructions_processed_LP/ALFRED_task_helper.py ADDED
@@ -0,0 +1,296 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Sat Mar 27 16:49:38 2021
+
+ @author: soyeonmin
+ """
+ import pickle
+ import alfred_utils.gen.constants as constants
+ import string
+
+ exclude = set(string.punctuation)
+ task_type_dict = {2: 'pick_and_place_simple',
+                   5: 'look_at_obj_in_light',
+                   1: 'pick_and_place_with_movable_recep',
+                   3: 'pick_two_obj_and_place',
+                   6: 'pick_clean_then_place_in_recep',
+                   4: 'pick_heat_then_place_in_recep',
+                   0: 'pick_cool_then_place_in_recep'}
+
+
+ def read_test_dict(test, appended, unseen):
+     #Select the instruction-to-parameters dictionary for the given split
+     if test:
+         if appended:
+             if unseen:
+                 return pickle.load(open("models/instructions_processed_LP/instruction2_params_tests_unseen_new_split_GT_aug28.p", "rb"))
+             else:
+                 return pickle.load(open("models/instructions_processed_LP/instruction2_params_tests_seen_new_split_GT_aug28.p", "rb"))
+         else:
+             if unseen:
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_test_unseen_916_noappended.p", "rb"))
+                 # REALFRED
+                 return pickle.load(open("models/instructions_processed_LP/instruction2_params_test_unseen_noappended_new_split_bert_trained.p", "rb"))
+             else:
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_test_seen_916_noappended.p", "rb"))
+                 # REALFRED
+                 return pickle.load(open("models/instructions_processed_LP/instruction2_params_test_seen_noappended_new_split_bert_trained.p", "rb"))
+     else:
+         if appended:
+             if unseen:
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_val_unseen_appended.p", "rb"))
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_aug28.p", "rb"))
+                 # tests_unseen
+                 return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_aug28_t.p", "rb"))
+             else:
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_val_seen_appended.p", "rb"))
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_aug28.p", "rb"))
+                 # tests_seen
+                 return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_aug28_t.p", "rb"))
+         else:
+             if unseen:
+                 # REALFRED
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_val_unseen_916_noappended.p", "rb"))
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_unseen_noappended_new_split.p", "rb"))
+                 return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_aug28_t.p", "rb"))
+             else:
+                 # return pickle.load(open("models/instructions_processed_LP/instruction2_params_val_seen_916_noappended.p", "rb"))
+                 return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_aug28_t.p", "rb"))
+
+ def exist_or_no(string):
+     if string == '' or string == False:
+         return 0
+     else:
+         return 1
+
+ def none_or_str(string):
+     if string == '':
+         return None
+     else:
+         return string
+
+ def get_arguments_test(test_dict, instruction):
+     task_type, mrecep_target, object_target, parent_target, sliced = \
+         test_dict[instruction]['task_type'], test_dict[instruction]['mrecep_target'], test_dict[instruction]['object_target'], test_dict[instruction]['parent_target'],\
+         test_dict[instruction]['sliced']
+
+     if isinstance(task_type, int):
+         task_type = task_type_dict[task_type]
+     return instruction, task_type, mrecep_target, object_target, parent_target, sliced
+
+
+ def get_arguments(traj_data):
+     task_type = traj_data['task_type']
+     try:
+         # r_idx = traj_data['repeat_idx']
+         r_idx = traj_data['ann']['repeat_idx']
+     except KeyError:
+         r_idx = 0
+     language_goal_instr = traj_data['turk_annotations']['anns'][r_idx]['task_desc']
+
+     sliced = exist_or_no(traj_data['pddl_params']['object_sliced'])
+     mrecep_target = none_or_str(traj_data['pddl_params']['mrecep_target'])
+     object_target = none_or_str(traj_data['pddl_params']['object_target'])
+     parent_target = none_or_str(traj_data['pddl_params']['parent_target'])
+     #toggle_target = none_or_str(traj_data['pddl_params']['toggle_target'])
+
+     return language_goal_instr, task_type, mrecep_target, object_target, parent_target, sliced
+
+ def add_target(target, target_action, list_of_actions):
+     #Openable targets (except Box) are wrapped in Open/Close actions
+     if target in [a for a in constants.OPENABLE_CLASS_LIST if not(a == 'Box')]:
+         list_of_actions.append((target, "OpenObject"))
+     list_of_actions.append((target, target_action))
+     if target in [a for a in constants.OPENABLE_CLASS_LIST if not(a == 'Box')]:
+         list_of_actions.append((target, "CloseObject"))
+     return list_of_actions
+
+ def determine_consecutive_interx(list_of_actions, previous_pointer, sliced=False):
+     #Check whether two consecutive subgoals act on the same target instance
+     returned, target_instance = False, None
+     if previous_pointer < len(list_of_actions)-1:  # previous_pointer+1 is indexed below, so a next action must exist
+         if list_of_actions[previous_pointer][0] == list_of_actions[previous_pointer+1][0]:
+             returned = True
+             #target_instance = list_of_target_instance[-1] #previous target
+             target_instance = list_of_actions[previous_pointer][0]
+         #Microwave or Fridge
+         elif list_of_actions[previous_pointer][1] == "OpenObject" and list_of_actions[previous_pointer+1][1] == "PickupObject":
+             returned = True
+             #target_instance = list_of_target_instance[0]
+             target_instance = list_of_actions[0][0]
+             if sliced:
+                 #target_instance = list_of_target_instance[3]
+                 target_instance = list_of_actions[3][0]
+         #Microwave or Fridge
+         elif list_of_actions[previous_pointer][1] == "PickupObject" and list_of_actions[previous_pointer+1][1] == "CloseObject":
+             returned = True
+             #target_instance = list_of_target_instance[-2] #e.g. Fridge
+             target_instance = list_of_actions[previous_pointer-1][0]
+         #Faucet
+         elif list_of_actions[previous_pointer+1][0] == "Faucet" and list_of_actions[previous_pointer+1][1] in ["ToggleObjectOn", "ToggleObjectOff"]:
+             returned = True
+             target_instance = "Faucet"
+         #Pick up after faucet
+         elif list_of_actions[previous_pointer][0] == "Faucet" and list_of_actions[previous_pointer+1][1] == "PickupObject":
+             returned = True
+             #target_instance = list_of_target_instance[0]
+             target_instance = list_of_actions[0][0]
+             if sliced:
+                 #target_instance = list_of_target_instance[3]
+                 target_instance = list_of_actions[3][0]
+     return returned, target_instance
+
+ def get_list_of_highlevel_actions(traj_data, test=False, test_dict=None, args_nonsliced=False, appended=False):
+     #Expand task type + PDDL parameters into a template of (target, interaction) subgoals
+     if not(test):
+         language_goal, task_type, mrecep_target, obj_target, parent_target, sliced = get_arguments(traj_data)
+     if test:
+         r_idx = traj_data['ann']['repeat_idx']
+         instruction = traj_data['turk_annotations']['anns'][r_idx]['task_desc']
+         # aug28
+
+         #if appended:
+         instruction = instruction.lower()
+         instruction = ''.join(ch for ch in instruction if ch not in exclude)
+         language_goal, task_type, mrecep_target, obj_target, parent_target, sliced = get_arguments_test(test_dict, instruction)
+
+     #obj_target = 'Tomato'
+     #mrecep_target = "Plate"
+     if parent_target == "Sink":
+         parent_target = "SinkBasin"
+     if parent_target == "Bathtub":
+         parent_target = "BathtubBasin"
+
+     #Change to this after the sliced happens
+     if args_nonsliced:
+         if sliced == 1:
+             obj_target = obj_target + 'Sliced'
+         #Map sliced to the same place in the map, but for ids like "|SinkBasin" look at the objectid
+
+
+     categories_in_inst = []
+     list_of_highlevel_actions = []
+     second_object = []
+     caution_pointers = []
+     #obj_target = "Tomato"
+
+     #if sliced:
+     #    obj_target = obj_target + 'Sliced'
+
+     if sliced == 1:
+         list_of_highlevel_actions.append(("Knife", "PickupObject"))
+         list_of_highlevel_actions.append((obj_target, "SliceObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions.append(("SinkBasin", "PutObject"))
+         categories_in_inst.append(obj_target)
+
+     if sliced:
+         obj_target = obj_target + 'Sliced'
+
+
+     if task_type == 'pick_cool_then_place_in_recep': #0 in new_labels
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target("Fridge", "PutObject", list_of_highlevel_actions)
+         list_of_highlevel_actions.append(("Fridge", "OpenObject"))
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         list_of_highlevel_actions.append(("Fridge", "CloseObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
+         categories_in_inst.append(obj_target)
+         categories_in_inst.append("Fridge")
+         categories_in_inst.append(parent_target)
+
+     elif task_type == 'pick_and_place_with_movable_recep': #1 in new_labels
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target(mrecep_target, "PutObject", list_of_highlevel_actions)
+         list_of_highlevel_actions.append((mrecep_target, "PickupObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
+         categories_in_inst.append(obj_target)
+         categories_in_inst.append(mrecep_target)
+         categories_in_inst.append(parent_target)
+
+     elif task_type == 'pick_and_place_simple': #2 in new_labels
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
+         #list_of_highlevel_actions.append((parent_target, "PutObject"))
+         categories_in_inst.append(obj_target)
+         categories_in_inst.append(parent_target)
+
+
+     elif task_type == 'pick_heat_then_place_in_recep': #4 in new_labels
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target("Microwave", "PutObject", list_of_highlevel_actions)
+         list_of_highlevel_actions.append(("Microwave", "ToggleObjectOn"))
+         list_of_highlevel_actions.append(("Microwave", "ToggleObjectOff"))
+         list_of_highlevel_actions.append(("Microwave", "OpenObject"))
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         list_of_highlevel_actions.append(("Microwave", "CloseObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
+         categories_in_inst.append(obj_target)
+         categories_in_inst.append("Microwave")
+         categories_in_inst.append(parent_target)
+
+     elif task_type == 'pick_two_obj_and_place': #3 in new_labels
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
+         if parent_target in constants.OPENABLE_CLASS_LIST:
+             second_object = [False] * 4
+         else:
+             second_object = [False] * 2
+         if sliced:
+             second_object = second_object + [False] * 3
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         #caution_pointers.append(len(list_of_highlevel_actions))
+         second_object.append(True)
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
+         second_object.append(False)
+         categories_in_inst.append(obj_target)
+         categories_in_inst.append(parent_target)
+
+
+     elif task_type == 'look_at_obj_in_light': #5 in new_labels
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         #if toggle_target == "DeskLamp":
+         #    print("Original toggle target was DeskLamp")
+         toggle_target = "FloorLamp"
+         list_of_highlevel_actions.append((toggle_target, "ToggleObjectOn"))
+         categories_in_inst.append(obj_target)
+         categories_in_inst.append(toggle_target)
+
+     elif task_type == 'pick_clean_then_place_in_recep': #6 in new_labels
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions.append(("SinkBasin", "PutObject")) #Sink or SinkBasin?
+         list_of_highlevel_actions.append(("Faucet", "ToggleObjectOn"))
+         list_of_highlevel_actions.append(("Faucet", "ToggleObjectOff"))
+         list_of_highlevel_actions.append((obj_target, "PickupObject"))
+         caution_pointers.append(len(list_of_highlevel_actions))
+         list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
+         categories_in_inst.append(obj_target)
+         categories_in_inst.append("SinkBasin")
+         categories_in_inst.append("Faucet")
+         categories_in_inst.append(parent_target)
+     else:
+         raise Exception("Task type not one of 0, 1, 2, 3, 4, 5, 6!")
+
+     if sliced == 1:
+         if not(parent_target == "SinkBasin"):
+             categories_in_inst.append("SinkBasin")
+
+     #return [(goal_category, interaction), (goal_category, interaction), ...]
+     print("instruction goal is ", language_goal)
+     #list_of_highlevel_actions = [('Microwave', 'OpenObject'), ('Microwave', 'PutObject'), ('Microwave', 'CloseObject')]
+     #list_of_highlevel_actions = [('Microwave', 'OpenObject'), ('Microwave', 'PutObject'), ('Microwave', 'CloseObject'), ('Microwave', 'ToggleObjectOn'), ('Microwave', 'ToggleObjectOff'), ('Microwave', 'OpenObject'), ('Apple', 'PickupObject'), ('Microwave', 'CloseObject'), ('Fridge', 'OpenObject'), ('Fridge', 'PutObject'), ('Fridge', 'CloseObject')]
+     #categories_in_inst = ['Microwave', 'Fridge']
+     return list_of_highlevel_actions, categories_in_inst, second_object, caution_pointers
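
Note: `read_test_dict` keys its lookups on the goal instruction after lower-casing and punctuation stripping, and `get_list_of_highlevel_actions` expands the looked-up parameters into (target, interaction) subgoals. A minimal usage sketch of the lookup path, assuming the repo root is on PYTHONPATH (so `alfred_utils` and a package import of this module resolve) and that the example instruction actually occurs as a key in the loaded dictionary:

import models.instructions_processed_LP.ALFRED_task_helper as helper

test_dict = helper.read_test_dict(test=False, appended=False, unseen=True)
instruction = 'put a clean sponge on the counter'  # hypothetical key
(_, task_type, mrecep_target, object_target,
 parent_target, sliced) = helper.get_arguments_test(test_dict, instruction)
print(task_type, object_target, parent_target, sliced)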
models/instructions_processed_LP/BERT/best_models/base.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2099e73c1ebdf22a9ae9ac122beaaf29da9e2cccf4eabf659919a20f42e9108b
+ size 438040777
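
The `.pt` entries in this commit are Git LFS pointer files, not the checkpoints themselves; `git lfs pull` replaces each three-line pointer (version, oid, size, as shown above) with the ~438 MB weight file. A quick sketch for checking what a pointer references before pulling (assumes the LFS objects have not been fetched yet, so the file on disk is still the small pointer):

# Parse a Git LFS pointer file into its version/oid/size fields.
def parse_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value
    return fields

ptr = parse_lfs_pointer('models/instructions_processed_LP/BERT/best_models/base.pt')
print(ptr['oid'], ptr['size'])  # sha256:2099e7... 438040777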
models/instructions_processed_LP/BERT/best_models/mrecep.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d47ad0207d1672ff9a12af409a7681eb16be8faf7af7465c2e40804deae2fee2
+ size 438419145
models/instructions_processed_LP/BERT/best_models/object.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56b153e583afa38ac5a7369d72f67a99092678e29472995cf62c4fa48fefe0c3
+ size 438419145
models/instructions_processed_LP/BERT/best_models/parent.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1c0b683d812d0fb0adaa6a92107c586e3dfb22bf20a43836df8b34a0ca317671
+ size 438148425
models/instructions_processed_LP/BERT/best_models/sliced.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:962b51b86dd1dcca9611ca0a1f6bc0b6a44d18e8de94cb07406459b5333ec749
+ size 438025417
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/correct_labels_dict_ppdl.p ADDED
Binary file (235 Bytes)

models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/correct_template_by_label_ppdl.p ADDED
Binary file (114 Bytes)

models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx.p ADDED
Binary file (1.66 kB)

models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx_new_split.p ADDED
Binary file (2.15 kB)

models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx.p ADDED
Binary file (552 Bytes)

models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx_new_split.p ADDED
Binary file (693 Bytes)

models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/template_by_label.p ADDED
Binary file (114 Bytes)

models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/toggle2idx.p ADDED
Binary file (213 Bytes)
 
models/instructions_processed_LP/BERT/data/alfred_data/create_text_with_pddl_low_appended.py ADDED
@@ -0,0 +1,112 @@
+ import os
+ import pandas as pd
+ import argparse
+ import json
+ import pickle
+ import string
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_path', type=str, default="/media/user/data/FILM/alfred_data_all/json_new", help="where to look for the generated data")
+ parser.add_argument('--split', type=str, default="tests_unseen")
+ parser.add_argument('-o','--output_name', type=str)
+ args = parser.parse_args()
+ data_path = args.data_path
+ split = args.split
+ exclude = set(string.punctuation)
+ result = dict()
+
+ # Path of the 'traj_data.json' files that were created when the trajectories were generated
+ # traj_data_path = "/media/user/data/alfred_4.3.0/gen/dataset/Finished"
+ traj_data_path = "/media/user/data/FILM/alfred_data_all/json_2.1.0/valid_unseen"
+ task_to_path = dict()
+ desc_to_gt_params = dict()  # e.g. 'wash the brown vegetable and put it on the counter': ...
+ JSON_FILENAME = "traj_data.json"
+ for dir_name, _, _ in os.walk(traj_data_path):
+     if "trial_" in dir_name and (not "raw_images" in dir_name) and (not "pddl_states" in dir_name) and (not "video" in dir_name):
+         json_file = os.path.join(dir_name, JSON_FILENAME)
+         if not os.path.isfile(json_file):
+             continue
+         task = json_file.split('/')[-2]
+         task_to_path[task] = json_file
+ # print(task_to_path)
+
+ task_types = {"pick_cool_then_place_in_recep": 0, "pick_and_place_with_movable_recep": 1, "pick_and_place_simple": 2, "pick_two_obj_and_place": 3, "pick_heat_then_place_in_recep": 4, "look_at_obj_in_light": 5, "pick_clean_then_place_in_recep": 6}
+
+ # 'task_desc': task
+ # e.g. {'wash the brown vegetable and put it on the counter': 'trial_T20230514_192232_882070', ...}
+ x_task = dict()
+
+ result['x'] = []; result['x_low'] = []
+ n = 0
+ d = 0
+ # with open('../../../../../alfred_data_small/splits/REALFRED_splits.json', 'r') as f:
+ # with open('alfred_data_small/splits/REALFRED_splits.json', 'r') as f:
+ with open('../../../../../alfred_data_small/splits/oct21.json', 'r') as f:
+     splits = json.load(f)
+ tasks = list()
+ for i in splits[split]:
+     task = i["task"]
+     if task not in tasks:
+         tasks.append(task)
+
+ for task in tasks:
+     with open(os.path.join(data_path, task, 'pp', 'ann_0.json')) as f:
+         ann_0 = json.load(f)
+
+     anns = ann_0['turk_annotations']['anns']  # anns = [{"assignment_id", "high_descs", "task_desc"}, {}, {}]
+     for j in anns:
+         task_desc = j['task_desc']
+         if task_desc[-1] == '.':
+             task_desc = task_desc[:-1]
+         task_desc = task_desc.lower()
+         task_desc = ''.join(ch for ch in task_desc if ch not in exclude)
+         if task_desc in result['x']:
+             print(task, task_desc)
+             d += 1
+         result['x'].append(task_desc)
+
+         x_low = ''
+         for k in j['high_descs']:
+             if k[-1] == '.':
+                 k = k[:-1]
+             k = k.lower()
+             k = ''.join(ch for ch in k if ch not in exclude)
+             x_low = x_low + k + '[SEP]'
+         x_low = x_low[:-5]  # strip the trailing '[SEP]' (5 characters)
+         result['x_low'].append(x_low)
+         n += 1
+
+         # x_task
+         x_task[task_desc] = task
+         path = task_to_path[task]
+         with open(path) as f:
+             traj_data = json.load(f)
+         params = dict()
+         task_type = task_types[traj_data['task_type']]
+         params['task_type'] = task_type
+         if traj_data['pddl_params']['mrecep_target'] != "":
+             params['mrecep_target'] = traj_data['pddl_params']['mrecep_target']
+         else:
+             params['mrecep_target'] = None
+         params['object_target'] = traj_data['pddl_params']['object_target']
+         if traj_data['pddl_params']['parent_target'] != "":
+             params['parent_target'] = traj_data['pddl_params']['parent_target']
+         else:
+             params['parent_target'] = None
+         if traj_data['pddl_params']['object_sliced']:
+             params['sliced'] = 1
+         else:
+             params['sliced'] = 0
+
+         desc_to_gt_params[task_desc] = params
+
+
+ # pickle.dump(result, open(args.output_name + ".p", "wb"))
+ pickle.dump(result, open(args.split + '_text_with_ppdl_low_appended.p', 'wb'))
+ # pickle.dump(x_task, open(args.split + '_task_desc_to_task_id.p', 'wb'))
+ pickle.dump(desc_to_gt_params, open('../../../instruction2_params_test_unseen_noappended_GT.p', 'wb'))
+ print(f"num of annotations: {n}")
+ print(f"num of duplicates: {d}")
+ print(f"num of non-duplicates: {n-d}")
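
The script above emits `result` (parallel lists: `x`, the processed goal instructions, and `x_low`, the step-by-step instructions joined with '[SEP]') plus `desc_to_gt_params`, a mapping from each processed goal to its ground-truth PDDL parameters. A sketch for sanity-checking the dumps (file names follow the `pickle.dump` calls above; the printed example strings are illustrative):

import pickle

result = pickle.load(open('tests_unseen_text_with_ppdl_low_appended.p', 'rb'))
assert len(result['x']) == len(result['x_low'])  # one entry per annotation
print(result['x'][0])      # processed goal, e.g. 'put a mug in the coffee maker'
print(result['x_low'][0])  # step instructions joined by '[SEP]'

gt = pickle.load(open('../../../instruction2_params_test_unseen_noappended_GT.p', 'rb'))
some_goal = next(iter(gt))
print(gt[some_goal])  # {'task_type': int, 'mrecep_target': str or None, 'object_target': str, 'parent_target': str or None, 'sliced': 0 or 1}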
models/instructions_processed_LP/BERT/data/alfred_data/create_text_with_pddl_low_appended_for_new_split.py ADDED
@@ -0,0 +1,191 @@
+ import os
+ import pandas as pd
+ import argparse
+ import json
+ import pickle
+ import string
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_path', type=str, default="/media/user/data2/FILM/alfred_data_all/json_2.1.0", help="where to look for the generated data")  # annotation path
+ parser.add_argument('--split', type=str, default="valid_unseen")
+ # parser.add_argument('-o','--output_name', type=str)
+ args = parser.parse_args()
+ data_path = args.data_path
+ split = args.split
+ exclude = set(string.punctuation)
+ result = dict()
+
+ # Path of the 'traj_data.json' files that were created when the trajectories were generated
+ # traj_data_path = "/media/user/data/alfred_4.3.0/gen/dataset/Finished"
+ traj_data_path = "/media/user/data2/FILM/alfred_data_all/Re_json_2.1.0"
+ task_to_path = dict()
+ desc_to_gt_params = dict()  # e.g. 'wash the brown vegetable and put it on the counter': ...
+ JSON_FILENAME = "traj_data.json"
+ for dir_name, _, _ in os.walk(traj_data_path):
+     if "trial_" in dir_name and (not "raw_images" in dir_name) and (not "pddl_states" in dir_name) and (not "video" in dir_name):
+         json_file = os.path.join(dir_name, JSON_FILENAME)
+         if not os.path.isfile(json_file):
+             continue
+         task = json_file.split('/')[-2]
+         task_to_path[task] = json_file
+ # print(task_to_path)
+
+ task_types = {"pick_cool_then_place_in_recep": 0, "pick_and_place_with_movable_recep": 1, "pick_and_place_simple": 2, "pick_two_obj_and_place": 3, "pick_heat_then_place_in_recep": 4, "look_at_obj_in_light": 5, "pick_clean_then_place_in_recep": 6}
+
+ # 'task_desc': task
+ # e.g. {'wash the brown vegetable and put it on the counter': 'trial_T20230514_192232_882070', ...}
+ x_task = dict()
+
+ if args.split in ['tests_seen', 'tests_unseen']:
+     result['x'] = []; result['x_low'] = []
+ if args.split in ['train', 'valid_seen', 'valid_unseen']:
+     result['x'] = []; result['y'] = []; result['s'] = []; result['mrecep_targets'] = []; result['object_targets'] = []; result['parent_targets'] = []; result['toggle_targets'] = []; result['x_low'] = []
+
+ n = 0
+ d = 0
+ # with open('../../../../../alfred_data_small/splits/REALFRED_splits.json', 'r') as f:
+ # with open('alfred_data_small/splits/REALFRED_splits.json', 'r') as f:
+ # with open('../../../../../alfred_data_small/splits/aug28.json', 'r') as f:
+ # with open('../../../../../alfred_data_small/splits/oct24.json', 'r') as f:
+ with open('/media/user/data2/FILM/alfred_data_small/splits/oct24.json', 'r') as f:
+     splits = json.load(f)
+ tasks = list()
+ for i in splits[split]:
+     task = i["task"]
+     print(task)
+     if task not in tasks:
+         tasks.append(task)
+ print("\n\n")
+ for task in tasks:
+     with open(os.path.join(traj_data_path, task, 'pp', 'ann_0.json')) as f:
+         ann_0 = json.load(f)
+
+     # anns = ann_0['turk_annotations']['anns'] # anns = [{"assignment_id", "high_descs", "task_desc"},{},{}]
+
+     ##### Without annotation #####
+     # anns = [ann_0['template']]
+
+     #0514
+     anns = ann_0['turk_annotations']['anns']
+
+     for j in anns:
+         task_desc = j['task_desc']
+         if len(task_desc) > 0:
+             if task_desc[-1] == '.':
+                 task_desc = task_desc[:-1]
+         task_desc = task_desc.lower()
+         task_desc = ''.join(ch for ch in task_desc if ch not in exclude)
+         if task_desc in result['x']:
+             print(task, task_desc)
+             print("====")
+             d += 1
+         result['x'].append(task_desc)
+
+         x_low = ''
+
+         ##### Without template #####
+         # for k in j['high_descs']:
+
+         ##### Without annotation #####
+         # for k in j['high_descs'][:-1]:
+         for k in j['high_descs']:
+             if len(k) > 0:
+                 if k[-1] == '.':
+                     k = k[:-1]
+             k = k.lower()
+             k = ''.join(ch for ch in k if ch not in exclude)
+             x_low = x_low + k + '[SEP]'
+         x_low = x_low[:-5]  # strip the trailing '[SEP]'
+         result['x_low'].append(x_low)
+         n += 1
+
+         # Extract ground-truth parameters from traj_data.json (GT)
+         # x_task
+         x_task[task_desc] = task
+
+         task_trial = task[-29:]
+         path = task_to_path[task_trial]
+         with open(path) as f:
+             traj_data = json.load(f)
+
+         params = dict()
+         task_type = task_types[traj_data['task_type']]
+         params['task_type'] = task_type
+         if traj_data['pddl_params']['mrecep_target'] != "":
+             params['mrecep_target'] = traj_data['pddl_params']['mrecep_target']
+         else:
+             params['mrecep_target'] = None
+         params['object_target'] = traj_data['pddl_params']['object_target']
+         if traj_data['pddl_params']['parent_target'] != "":
+             params['parent_target'] = traj_data['pddl_params']['parent_target']
+         else:
+             params['parent_target'] = None
+         if traj_data['pddl_params']['object_sliced']:
+             params['sliced'] = 1
+         else:
+             params['sliced'] = 0
+         if traj_data['pddl_params']['toggle_target'] != "":
+             params['toggle_target'] = traj_data['pddl_params']['toggle_target']
+         else:
+             params['toggle_target'] = None
+
+         desc_to_gt_params[task_desc] = params
+
+         if args.split in ['train', 'valid_seen', 'valid_unseen']:
+             # obj2idx_new_split, recep2idx_new_split
+             # obj2idx = pickle.load(open('alfred_dicts/obj2idx_new_split.p', 'rb'))
+             # recep2idx = pickle.load(open('alfred_dicts/recep2idx_new_split.p', 'rb'))
+             # toggle2idx = pickle.load(open('alfred_dicts/toggle2idx.p', 'rb'))
+             obj2idx = pickle.load(open('models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx_new_split.p', 'rb'))
+             recep2idx = pickle.load(open('models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx_new_split.p', 'rb'))
+             toggle2idx = pickle.load(open('models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/toggle2idx.p', 'rb'))
+
+             result['y'].append(task_type)
+
+             if traj_data['pddl_params']['object_sliced']:
+                 result['s'].append(1)
+             else:
+                 result['s'].append(0)
+
+             if traj_data['pddl_params']['mrecep_target'] != "":
+                 result['mrecep_targets'].append(obj2idx[traj_data['pddl_params']['mrecep_target']])
+             else:
+                 result['mrecep_targets'].append(obj2idx[None])
+
+             result['object_targets'].append(obj2idx[traj_data['pddl_params']['object_target']])
+
+             if traj_data['pddl_params']['parent_target'] != "":
+                 result['parent_targets'].append(recep2idx[traj_data['pddl_params']['parent_target']])
+             else:
+                 result['parent_targets'].append(recep2idx[None])
+
+             if traj_data['pddl_params']['toggle_target'] != "":
+                 result['toggle_targets'].append(toggle2idx[traj_data['pddl_params']['toggle_target']])
+             else:
+                 result['toggle_targets'].append(toggle2idx[None])
+
+
+ # pickle.dump(result, open(args.output_name + ".p", "wb"))
+ # pickle.dump(result, open(args.split + '_text_with_ppdl_low_appended_new_split_aug28_t.p', 'wb'))
+ # pickle.dump(x_task, open(args.split + '_task_desc_to_task_id_aug28_t.p', 'wb'))
+ # pickle.dump(desc_to_gt_params, open('../../../instruction2_params_' + args.split + '_new_split_GT_aug28_t.p', 'wb'))
+
+ # 0514
+ # pickle.dump(result, open(args.split + '_text_with_ppdl_low_appended_new_split_oct24.p', 'wb'))
+ # pickle.dump(x_task, open(args.split + '_task_desc_to_task_id_oct24.p', 'wb'))
+ # pickle.dump(desc_to_gt_params, open('../../../instruction2_params_' + args.split + '_new_split_GT_oct24.p', 'wb'))
+ pickle.dump(result, open('models/instructions_processed_LP/BERT/data/alfred_data/' + args.split + '_text_with_ppdl_low_appended_new_split_oct24.p', 'wb'))
+ pickle.dump(x_task, open('models/instructions_processed_LP/BERT/data/alfred_data/' + args.split + '_task_desc_to_task_id_oct24.p', 'wb'))
+ pickle.dump(desc_to_gt_params, open('models/instructions_processed_LP/instruction2_params_' + args.split + '_new_split_GT_oct24.p', 'wb'))
+ print(f"num of annotations: {n}")
+ print(f"num of duplicates: {d}")
+ print(f"num of non-duplicates: {n-d}")
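
For train/valid splits this variant also encodes the labels: `result['y']` holds the task type, `result['s']` the sliced flag, and the `*_targets` lists hold integer indices from the `*_new_split` dictionaries. Those dictionaries are plain `{name_or_None: index}` pickles, so the mapping inverts cleanly, which is how `end_to_end_outputs.py` below turns predictions back into names. Roughly (assuming 'Mug' is in the object vocabulary):

import pickle

obj2idx = pickle.load(open('models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx_new_split.p', 'rb'))
idx2obj = {v: k for k, v in obj2idx.items()}

idx = obj2idx['Mug']           # assumed vocabulary entry
assert idx2obj[idx] == 'Mug'   # round trip
print(idx2obj[obj2idx[None]])  # None encodes "no mrecep/parent target"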
models/instructions_processed_LP/BERT/data/alfred_data/tests_seen_task_desc_to_task_id_oct24.p ADDED
Binary file (155 kB)

models/instructions_processed_LP/BERT/data/alfred_data/tests_seen_text_with_ppdl_low_appended_new_split_oct24.p ADDED
Binary file (767 kB)

models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_task_desc_to_task_id_oct24.p ADDED
Binary file (149 kB)

models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_GT.p ADDED
Binary file (61.9 kB)

models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_new_split_oct24.p ADDED
Binary file (756 kB)

models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_test_unseen_new_split_GT.p ADDED
Binary file (167 kB)
 
models/instructions_processed_LP/BERT/data/alfred_data/train_task_desc_to_task_id_oct24.p ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d48d00106a760795caef5f1dbbe76162925fd450a43c2bfeb406453974b26f3
+ size 2003819
models/instructions_processed_LP/BERT/data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33039ca77d04ab3e9806d7de5e687620144dbcffc42efcead167c77054e69307
+ size 9665585
models/instructions_processed_LP/BERT/data/alfred_data/valid_seen_task_desc_to_task_id_oct24.p ADDED
Binary file (114 kB)

models/instructions_processed_LP/BERT/data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split_oct24.p ADDED
Binary file (480 kB)

models/instructions_processed_LP/BERT/data/alfred_data/valid_unseen_task_desc_to_task_id_oct24.p ADDED
Binary file (111 kB)

models/instructions_processed_LP/BERT/data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split_oct24.p ADDED
Binary file (461 kB)
 
models/instructions_processed_LP/BERT/end_to_end_outputs.py ADDED
@@ -0,0 +1,210 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Wed Feb 3 23:52:39 2021
+
+ @author: soyeonmin
+ """
+
+ #Run base model (into templates) and then extract arguments
+
+ import random
+ import time
+ import torch
+ from torch import nn
+ import pickle
+ import glob
+ import argparse
+ import os
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-sp','--split', type=str, choices=['valid_unseen', 'valid_seen', 'tests_seen', 'tests_unseen'], required=True)
+ parser.add_argument('-m','--model_saved_folder_name', type=str, required=True, default='best_models')
+ parser.add_argument('-o','--output_name', type=str, required=True)
+ parser.add_argument('--no_appended', action='store_true')
+
+ args = parser.parse_args()
+
+
+ def accuracy(y_pred, y_batch):
+     #y_pred has shape [batch, no_classes]
+     maxed = torch.max(y_pred, 1)
+     y_hat = maxed.indices
+     num_accurate = torch.sum((y_hat == y_batch).long())
+     train_accuracy = num_accurate / y_hat.shape[0]
+     return train_accuracy.item()
+
+ def accurate_both(y_pred1, y_batch1, y_pred2, y_batch2):
+     maxed1 = torch.max(y_pred1, 1)
+     y_hat1 = maxed1.indices
+     maxed2 = torch.max(y_pred2, 1)
+     y_hat2 = maxed2.indices
+     num_both_accurate = torch.sum((y_hat1 == y_batch1).long() * (y_hat2 == y_batch2).long())
+     train_accuracy = num_both_accurate / y_hat1.shape[0]
+     return train_accuracy.item()
+
+
+ #Load data
+ val_set_unseen = pickle.load(open('data/alfred_data/' + args.split + '_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
+ desc2id = pickle.load(open('data/alfred_data/' + args.split + '_task_desc_to_task_id_oct24.p', 'rb'))
+
+ # REALFRED new split
+ obj2idx = pickle.load(open('data/alfred_data/alfred_dicts/obj2idx_new_split.p', 'rb'))
+ recep2idx = pickle.load(open('data/alfred_data/alfred_dicts/recep2idx_new_split.p', 'rb'))
+
+ toggle2idx = pickle.load(open('data/alfred_data/alfred_dicts/toggle2idx.p', 'rb'))
+
+ idx2obj = {v: k for k, v in obj2idx.items()}
+ idx2recep = {v: k for k, v in recep2idx.items()}
+ idx2toggle = {v: k for k, v in toggle2idx.items()}
+
+ #These are based on new labels
+ task_to_label_mapping = {'mrecep_target': [1], 'object_target': [0,1,2,3,4,5,6],
+                          'parent_target': [0,1,2,3,4,6], 'toggle_target': [5],
+                          'sliced': [0,1,2,3,4,5,6]}
+
+ #Set device
+ if torch.cuda.is_available():
+     device = torch.device('cuda')
+ else:
+     device = torch.device('cpu')
+
+
+ save_folder_name = args.model_saved_folder_name
+ #Base model
+ from transformers import BertForSequenceClassification
+ base_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=7).to(device)
+
+ base_model_name = 'base.pt'
+ base_model.load_state_dict(torch.load(os.path.join(save_folder_name, base_model_name)))
+ base_model.eval()
+
+ #Run data through base model and get label
+ from transformers import BertTokenizer
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+ #Put validation data into the base_model
+ x_val_seen = val_set_unseen['x_low'] #; y_val_seen = val_set_seen['y']
+ if args.no_appended:
+     x_val_seen = val_set_unseen['x']
+
+
+ encoding_v_s = tokenizer(x_val_seen, return_tensors='pt', padding=True, truncation=True)
+ input_ids_val_seen = encoding_v_s['input_ids'].to(device)
+ attention_mask_val_seen = encoding_v_s['attention_mask'].to(device)
+
+ N = 10
+ y_hat_list_vs = []
+ if input_ids_val_seen.shape[0] % N != 0:
+     until = int(input_ids_val_seen.shape[0]/N) + 1
+ else:
+     until = int(input_ids_val_seen.shape[0]/N)
+ for b in range(until):
+     input_ids_batch = input_ids_val_seen[N*b:N*(b+1)].to(device)
+     attention_mask_batch = attention_mask_val_seen[N*b:N*(b+1)].to(device)
+     outputs = base_model(input_ids_batch, attention_mask=attention_mask_batch)
+     predicted_templates = torch.max(outputs.logits, 1).indices
+     y_hat_list_vs += predicted_templates.cpu().numpy().tolist()
+
+ del outputs
+ del base_model
+ vs_idx2predicted_label = {i: y for i, y in enumerate(y_hat_list_vs)}
+
+ #Now extract the arguments
+ c = 0
+ def get_prediction(classifier, N, input_ids, attention_mask):
+     global c
+     y_hat_list = []
+     for b in range(int(input_ids.shape[0]/N)+1):
+         if b != int(input_ids.shape[0]/N):
+             input_ids_batch = input_ids[N*b:N*(b+1)].to(device)
+         else:
+             input_ids_batch = input_ids[N*b:].to(device)
+         if input_ids_batch.shape[0] == 0:
+             break  # no remainder batch when the dataset size divides evenly by N
+         attention_mask_batch = attention_mask[N*b:N*(b+1)].to(device)
+
+         #outputs = base_model(input_ids_batch, attention_mask=attention_mask_batch, labels=labels_batch.view(1,-1))
+         # print('input_ids_batch: ', input_ids_batch.shape)
+         outputs = classifier(input_ids_batch, attention_mask=attention_mask_batch)
+         c += 1
+         try:
+             print(c, torch.max(outputs.logits, 1).indices)
+         except Exception:
+             print(outputs)
+             print(c)
+         # print(outputs.logits)
+         predicted_templates = torch.max(outputs.logits, 1).indices
+         del outputs
+         y_hat_list += predicted_templates.cpu().numpy().tolist()
+     return y_hat_list
+
+ #Prepend the predicted template label to each instruction
+ x_val_seen_p = [str(int(vs_idx2predicted_label[i])) + ' ' + x for i, x in enumerate(x_val_seen)]
+
+ encoding_v_s = tokenizer(x_val_seen_p, return_tensors='pt', padding=True, truncation=True)
+ print(encoding_v_s)
+ input_ids_val_seen = encoding_v_s['input_ids'].to(device)
+ print(input_ids_val_seen.shape)
+ attention_mask_val_seen = encoding_v_s['attention_mask'].to(device)
+
+ parent_target_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(recep2idx)).to(device)
+ parent_target_classifier.load_state_dict(torch.load(os.path.join(save_folder_name, 'parent.pt')))
+ parent_outputs_hat = get_prediction(parent_target_classifier, 9, input_ids_val_seen, attention_mask_val_seen)
+ del parent_target_classifier
+ c = 0
+
+ object_target_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(obj2idx)).to(device)
+ object_target_classifier.load_state_dict(torch.load(os.path.join(save_folder_name, 'object.pt')))
+ object_outputs_hat = get_prediction(object_target_classifier, 9, input_ids_val_seen, attention_mask_val_seen)
+ del object_target_classifier
+ c = 0
+
+ sliced_target_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)
+ sliced_target_classifier.load_state_dict(torch.load(os.path.join(save_folder_name, 'sliced.pt')))
+ sliced_outputs_hat = get_prediction(sliced_target_classifier, 9, input_ids_val_seen, attention_mask_val_seen)
+ del sliced_target_classifier
+ c = 0
+
+ mrecep_target_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(obj2idx)).to(device)
+ mrecep_target_classifier.load_state_dict(torch.load(os.path.join(save_folder_name, 'mrecep.pt')))
+ mrecep_outputs_hat = get_prediction(mrecep_target_classifier, 9, input_ids_val_seen, attention_mask_val_seen)
+ del mrecep_target_classifier
+
+ instructions = val_set_unseen['x_low']
+ if args.no_appended:
+     instructions = val_set_unseen['x']
+ instructions = val_set_unseen['x']  # keys of the output dict are the (processed) goal instructions
+ instruction2_params_test_unseen = {}
+ for i, instruction in enumerate(instructions):
+     task_type = vs_idx2predicted_label[i]
+     object_target = idx2obj[object_outputs_hat[i]]
+     if parent_outputs_hat is None:
+         parent_target = None
+     else:
+         parent_target = idx2recep[parent_outputs_hat[i]]
+     if mrecep_outputs_hat is None:
+         mrecep_target = None
+     else:
+         mrecep_target = idx2obj[mrecep_outputs_hat[i]]
+     sliced_target = sliced_outputs_hat[i]
+
+     if task_type == 5:
+         parent_target = None
+     # if task_type != 1:
+     #     mrecep_target = None
+
+     instruction2_params_test_unseen[instruction] = {'task_type': task_type,
+                                                     'mrecep_target': mrecep_target,
+                                                     'sliced': sliced_target,
+                                                     'object_target': object_target,
+                                                     'parent_target': parent_target}
+
+
+ pickle.dump(instruction2_params_test_unseen, open("../" + args.output_name + ".p", "wb"))
+
+
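
The dumped pickle has exactly the schema `read_test_dict` in ALFRED_task_helper.py expects: each (processed) goal instruction maps to the five predicted parameters. A sketch of reading the predictions back ('my_preds' stands in for whatever was passed as -o/--output_name):

import pickle

preds = pickle.load(open('../my_preds.p', 'rb'))  # hypothetical output name
for instruction, params in list(preds.items())[:3]:
    print(instruction)
    print('  task_type:', params['task_type'], '| object:', params['object_target'],
          '| parent:', params['parent_target'], '| mrecep:', params['mrecep_target'],
          '| sliced:', params['sliced'])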
models/instructions_processed_LP/BERT/train_bert_args.py ADDED
@@ -0,0 +1,301 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Wed Feb 3 21:49:31 2021
5
+
6
+ @author: soyeonmin
7
+ """
8
+ from tensorboardX import SummaryWriter
9
+ from tqdm import trange
10
+ import random
11
+ import time
12
+ import torch
13
+ from torch import nn
14
+ import os
15
+ import glob
16
+ from collections import OrderedDict
17
+ import argparse
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument('-lr','--learning_rate', type=float, default=1e-5, help="learning rate")
20
+ parser.add_argument('-s','--seed', type=int, default=0, help="seed")
21
+ parser.add_argument('-l','--label', type=int, default=1, help="template")
22
+ parser.add_argument('-d','--decay', type=float, default=0.5, help="template")
23
+ parser.add_argument('-dt','--decay_term', type=int, default=50, help="template")
24
+ parser.add_argument('-t','--task', type=str, help="type in one of mrecep, object, parent, toggle, sliced")
25
+ parser.add_argument('-v','--verbose', type=int, default=0, help="print training output")
26
+ parser.add_argument('-model_type','--model_type', type=str, default='bert-base', help="one of roberta-large, roberta-base, bert-base, bert-large")
27
+ parser.add_argument('-load','--load', type=str, default='', help="one of roberta-large, roberta-base, bert-base, bert-large")
28
+ parser.add_argument('-no_divided_label','--no_divided_label', action='store_true')
29
+ parser.add_argument('--no_appended', action='store_true')
30
+
31
+
32
+ args = parser.parse_args()
33
+
34
+ #Set device
35
+ if torch.cuda.is_available():
36
+ device = torch.device('cuda')
37
+ else:
38
+ device = torch.device('cpu')
39
+
40
+ import pickle
41
+ template_by_label = pickle.load(open('data/alfred_data/alfred_dicts/template_by_label.p', 'rb'))
42
+ # train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended.p', 'rb'))
43
+ # val_set_seen = pickle.load(open('data/alfred_data/val_seen_text_with_ppdl_low_appended.p', 'rb'))
44
+ # val_set_unseen = pickle.load(open('data/alfred_data/val_unseen_text_with_ppdl_low_appended.p', 'rb'))
45
+
46
+ # obj2idx = pickle.load(open('data/alfred_data/alfred_dicts/obj2idx.p', 'rb'))
47
+ # recep2idx = pickle.load(open('data/alfred_data/alfred_dicts/recep2idx.p', 'rb'))
48
+ # train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended_new_split.p', 'rb'))
49
+ # val_set_seen = pickle.load(open('data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split.p', 'rb'))
50
+ # val_set_unseen = pickle.load(open('data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split.p', 'rb'))
51
+
52
+ # 0514
53
+ train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
54
+ val_set_seen = pickle.load(open('data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
55
+ val_set_unseen = pickle.load(open('data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
56
+
57
+ obj2idx = pickle.load(open('data/alfred_data/alfred_dicts/obj2idx_new_split.p', 'rb'))
58
+ recep2idx = pickle.load(open('data/alfred_data/alfred_dicts/recep2idx_new_split.p', 'rb'))
59
+ toggle2idx = pickle.load(open('data/alfred_data/alfred_dicts/toggle2idx.p', 'rb'))
60
+
61
+ assert args.task in ['mrecep', 'object', 'parent', 'toggle', 'sliced']
62
+ if args.task == 'mrecep':
63
+ label_arg = 'mrecep_targets'; num_labels = len(obj2idx)
64
+ elif args.task == 'object':
65
+ label_arg = 'object_targets'; num_labels = len(obj2idx)
66
+ elif args.task == 'parent':
67
+ label_arg = 'parent_targets'; num_labels = len(recep2idx)
68
+ elif args.task == 'toggle':
69
+ label_arg = 'toggle_targets'; num_labels = len(toggle2idx)
70
+ elif args.task == 'sliced':
71
+ label_arg = 's'; num_labels = 2
72
+
73
+
74
+
75
+
76
+
77
+ if args.model_type in ['bert-base', 'bert-large'] :
78
+ from transformers import BertTokenizer as Tokenizer
79
+ tok_type = args.model_type + '-uncased'
80
+ from transformers import BertForSequenceClassification as BertModel
81
+ elif args.model_type in ['roberta-base', 'roberta-large']:
82
+ from transformers import RobertaTokenizer as Tokenizer
83
+ tok_type = args.model_type
84
+ from transformers import RobertaForSequenceClassification as BertModel
85
+ tokenizer = Tokenizer.from_pretrained(tok_type)
86
+
87
+ torch.manual_seed(args.seed)
88
+ model = BertModel.from_pretrained(tok_type, num_labels=num_labels)
89
+
90
+ if len(args.load) >0:
91
+ sd = OrderedDict()
92
+ sd_ori = torch.load(args.load, map_location = 'cpu')
93
+ for k in sd_ori:
94
+ if 'classifier' in k:
95
+ if k =='classifier.weight':
96
+ sd[k] = model.classifier.weight
97
+ elif k =='classifier.bias':
98
+ sd[k] = model.classifier.bias
99
+ else:
100
+ sd[k] = sd_ori[k]
101
+ del sd_ori
102
+
103
+ model.load_state_dict(sd)
104
+
105
+ model = model.to(device)
106
+
107
+ ####################################################
108
+ ## 1. Prepare Data
109
+ ####################################################
110
+
111
+
112
+ from random import shuffle
113
+ import numpy as np
114
+ x_train = train_set['x_low']; y_train = train_set['y']
115
+ if args.no_appended:
116
+ x_train = train_set['x']
117
+ #if args.no_divided_label:
118
+ x_train = [str(y_train[i])+ ' ' + x for i, x in enumerate(x_train)]
119
+
120
+ if args.task == 'sliced' or args.no_divided_label:
121
+ x_ones = [i for i in range(len(x_train))]
122
+ else:
123
+ x_ones = [i for i, lab in enumerate(y_train) if lab == args.label]
124
+ random.seed(0)
125
+ shuffled = [i for i in x_ones]
126
+ shuffle(shuffled)
127
+ #x_train = np.array(train_set['x_low'])[shuffled]
128
+ x_train = np.array(x_train)[shuffled]
129
+ x_train = x_train.tolist()
130
+ encoding = tokenizer(x_train, return_tensors='pt', padding=True, truncation=True)
131
+ input_ids = encoding['input_ids'].to(device) #has shape [21025, 37]
132
+ attention_mask = encoding['attention_mask'].to(device) #also has shape [21025, 37]
133
+
134
+ label_train = np.array(train_set[label_arg])[shuffled]
135
+ label_train = torch.tensor(label_train).to(device)
136
+
137
+
138
+
139
+ x_val_seen = val_set_seen['x_low']; y_val_seen = val_set_seen['y']
140
+ if args.no_appended:
141
+ x_val_seen = val_set_seen['x']
142
+ x_val_seen = [str(y_val_seen[i])+ ' ' + x for i, x in enumerate(x_val_seen)]
143
+
144
+ if args.task == 'sliced' or args.no_divided_label:
145
+ random.seed(0)
146
+ x_ones_vs = random.sample(range(len(x_val_seen)), 100)
147
+ else:
148
+ x_ones_vs = [i for i, lab in enumerate(y_val_seen) if lab == args.label]
149
+ x_val_seen = np.array(x_val_seen)[x_ones_vs].tolist()
150
+ encoding_v_s = tokenizer(x_val_seen, return_tensors='pt', padding=True, truncation=True)
151
+ input_ids_val_seen = encoding_v_s['input_ids'].to(device)
152
+ attention_mask_val_seen = encoding_v_s['attention_mask'].to(device)
153
+
154
+ label_val_seen = [val_set_seen[label_arg][i] for i in x_ones_vs]
155
+ label_val_seen = torch.tensor(label_val_seen).to(device)
156
+
157
+
158
+
159
+ ####################################################
160
+ ## 2. Do training
161
+ ####################################################
162
+
163
+ model.train()
164
+ from transformers import AdamW
165
+ learning_rate = args.learning_rate
166
+ optimizer = AdamW(model.parameters(), lr=learning_rate)
167
+ if 'base' in args.model_type:
168
+ N= 64
169
+ else:
170
+ N = 32
171
+
172
+ def accuracy(y_pred, y_batch):
173
+ #y_pred has shape [batch, no_classes]
174
+ maxed = torch.max(y_pred, 1)
175
+ y_hat = maxed.indices
176
+ num_accurate = torch.sum((y_hat == y_batch).float())
177
+ train_accuracy = num_accurate/ y_hat.shape[0]
178
+ return train_accuracy.item()
179
+
180
+ def accurate_total(y_pred, y_batch):
181
+ #y_pred has shape [batch, no_classes]
182
+ maxed = torch.max(y_pred, 1)
183
+ y_hat = maxed.indices
184
+ num_accurate = torch.sum((y_hat == y_batch).float())
185
+ return num_accurate
186
+
187
+
188
+ if args.no_appended:
189
+ super_folder = 'saved_models_noappended_new_split_oct24/'
190
+ else:
191
+ super_folder = 'saved_models_appended_new_split_oct24/'
192
+
193
+ if args.task == 'sliced':
194
+ save_folder_name = args.task +'/' + args.model_type + '_lr_' + str(args.learning_rate) + 'seed_' + str(args.seed) + 'decay_' + str(args.decay) +'/'
195
+ else:
196
+ save_folder_name = args.task + str(args.label) + '/' + args.model_type+ '_lr_' + str(args.learning_rate) + 'seed_' + str(args.seed) + 'decay_' + str(args.decay) +'/'
197
+ if not os.path.exists(super_folder+'argument_models/' + save_folder_name):
198
+ os.makedirs(super_folder+'argument_models/' + save_folder_name)
199
+
200
+ accuracy_dictionary = {'training_loss': [], 'training':[], 'test':[]}
201
+ start_train_time = time.time()
202
+
203
+ summary_writer = SummaryWriter(log_dir=super_folder+'argument_models/' + save_folder_name)
204
+ train_iter = 0
205
+
+ for t in trange(50):
+     model.train()
+     # Decay the learning rate every `decay_term` epochs.
+     if t > 0 and (t+1) % args.decay_term == 0:
+         learning_rate *= args.decay
+         optimizer = AdamW(model.parameters(), lr=learning_rate)
+     avg_training_loss = 0.0
+     training_acc = 0.0
+     for b in trange(int(input_ids.shape[0]/N)):
+         input_ids_batch = input_ids[N*b:N*(b+1)].to(device)
+         labels_batch = label_train[N*b:N*(b+1)].to(device)
+         attention_mask_batch = attention_mask[N*b:N*(b+1)].to(device)
+         optimizer.zero_grad()
+         # forward pass
+         outputs = model(input_ids_batch, attention_mask=attention_mask_batch, labels=labels_batch.view(1,-1))
+         loss = outputs.loss
+         summary_writer.add_scalar('train/loss', loss, train_iter)
+
+         num_acc = accurate_total(y_pred=outputs.logits, y_batch=labels_batch.view(-1))
+         loss.backward()
+         train_iter += 1
+         optimizer.step()
+
+         avg_training_loss += loss
+         training_acc += num_acc
+     avg_training_loss *= 1/int(input_ids.shape[0]/N)
+     # Print & evaluate
+     if args.verbose:
+         print("loss at step ", t, " : ", loss.item())
+         print("training accuracy: ", training_acc/input_ids.shape[0])
+     # evaluate
+     with torch.no_grad():
+         model.eval()
+         outputs_val_seen = model(input_ids_val_seen, attention_mask=attention_mask_val_seen)
+         val_seen_acc = accuracy(y_pred=outputs_val_seen.logits, y_batch=label_val_seen.view(-1))
+         if args.verbose:
+             print("validation accuracy (seen): ", val_seen_acc)
+         del outputs_val_seen
+
+     model_name = 'epoch_' + str(t) + '.pt'
+     torch.save(model.state_dict(), super_folder + 'argument_models/' + save_folder_name + model_name)
+     accuracy_dictionary['training_loss'].append(loss.item())
+     accuracy_dictionary['training'].append(training_acc/input_ids.shape[0])
+     accuracy_dictionary['test'].append(val_seen_acc)
+     summary_writer.add_scalar('valid/accuracy', val_seen_acc, train_iter)
+
+     print("Epoch " + str(t) + "\n")
+     print("training loss: " + str(accuracy_dictionary['training_loss'][t]))
+     print("training accuracy: " + str(accuracy_dictionary['training'][t]))
+     print("validation (seen) accuracy: " + str(accuracy_dictionary['test'][t]))
+
+ # Get the epochs with the highest validation accuracy and delete the rest.
+ highest_test = np.argwhere(accuracy_dictionary['test'] == np.amax(accuracy_dictionary['test']))
+ highest_test = highest_test.flatten().tolist()
+
+ # Break ties on validation accuracy by picking the epoch with the highest training accuracy.
+ training_acc_highest_h = np.argmax([accuracy_dictionary['training'][h].detach().cpu().numpy() for h in highest_test])
+ best_t = highest_test[training_acc_highest_h]
+ print("The best model is 'epoch_" + str(best_t) + ".pt'")
+
+ # Delete every saved model except for best_t
+ file_path = super_folder + 'argument_models/' + save_folder_name + "epoch_" + str(best_t) + ".pt"
+ if os.path.isfile(file_path):
+     for CleanUp in glob.glob(super_folder + 'argument_models/' + save_folder_name + '*.pt'):
+         if not CleanUp.endswith(file_path):
+             os.remove(CleanUp)
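+ # Illustrative example (hypothetical epoch numbers): if epochs 12 and 30 tie
+ # for the highest validation accuracy, the one with the higher training
+ # accuracy becomes best_t, and the loop above removes every other epoch_*.pt
+ # checkpoint in the folder.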
+
+ # Save the training/test accuracy dictionary in a txt file; the best epoch is
+ # framed by separator lines.
+ f = open(super_folder + 'argument_models/' + save_folder_name + "training_log.txt", "w")
+ for t in range(100):
+     if t >= len(accuracy_dictionary['training_loss']):
+         break
+     if t == best_t:
+         f.write("===========================================\n")
+     f.write("Epoch " + str(t) + "\n")
+     f.write("training loss: " + str(accuracy_dictionary['training_loss'][t]) + "\n")
+     f.write("training accuracy: " + str(accuracy_dictionary['training'][t]) + "\n")
+     f.write("validation (seen) accuracy: " + str(accuracy_dictionary['test'][t]) + "\n")
+     if t == best_t:
+         f.write("===========================================\n")
+ f.close()
+
+ print("Saved and finished")
+
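One plausible invocation of this argument-model trainer, assuming the argparse flags defined earlier in the file mirror the `args.*` attribute names used above (the exact flag spellings and hyperparameter values here are illustrative, not taken from the source):

    python train_bert_args.py --task sliced --learning_rate 1e-5 --model_type bert-base --seed 0 --decay 0.5 --decay_term 10

With `--task sliced` the script samples a fixed validation subset; for other tasks it filters the validation set to examples whose task-type label equals `--label`.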
models/instructions_processed_LP/BERT/train_bert_base.py ADDED
@@ -0,0 +1,179 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Wed Jan 20 11:41:44 2021
+
+ @author: soyeonmin
+ """
+
+ # Train BERT on ALFRED data
+ import tensorboard
+ from tensorboardX import SummaryWriter
+ import random
+ import time
+ import torch
+ from torch import nn
+ import os
+ from tqdm import trange
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-lr', '--learning_rate', type=float, help="learning rate")
+ parser.add_argument('-model_type', '--model_type', type=str, default='bert-base', help="one of roberta-large, roberta-base, bert-base, bert-large")
+ parser.add_argument('-task_desc_only', '--task_desc_only', action='store_true', help="use only the high-level task description as input")
+ parser.add_argument('--no_appended', action='store_true')
+
+ args = parser.parse_args()
+
+ ##########
+ #### 1. Data loading and processing
+ #########
+
+ # Set device
+ if torch.cuda.is_available():
+     device = torch.device('cuda')
+ else:
+     device = torch.device('cpu')
+
+ import pickle
+ directory = 'data/alfred_data/'
+ # train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended.p', 'rb'))
+ # val_set_seen = pickle.load(open('data/alfred_data/val_seen_text_with_ppdl_low_appended.p', 'rb'))
+ # val_set_unseen = pickle.load(open('data/alfred_data/val_unseen_text_with_ppdl_low_appended.p', 'rb'))
+
+ # train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended_new_split.p', 'rb'))
+ # val_set_seen = pickle.load(open('data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split.p', 'rb'))
+ # val_set_unseen = pickle.load(open('data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split.p', 'rb'))
+
+ train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
+ val_set_seen = pickle.load(open('data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
+ val_set_unseen = pickle.load(open('data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
+
+ # Now process like huggingface
+ if args.no_appended:
+     x_train = train_set['x']; y_train = train_set['y']
+ else:
+     x_train = train_set['x_low']; y_train = train_set['y']
+ labels = torch.tensor(y_train).to(device)
+
+ if args.model_type in ['bert-base', 'bert-large']:
+     from transformers import BertTokenizer as Tokenizer
+     tok_type = args.model_type + '-uncased'
+     from transformers import BertForSequenceClassification as BertModel
+ elif args.model_type in ['roberta-base', 'roberta-large']:
+     from transformers import RobertaTokenizer as Tokenizer
+     tok_type = args.model_type
+     from transformers import RobertaForSequenceClassification as BertModel
+ tokenizer = Tokenizer.from_pretrained(tok_type)
+
+ model = BertModel.from_pretrained(tok_type, num_labels=7).to(device)  # 7 ALFRED task types
+ encoding = tokenizer(x_train, return_tensors='pt', padding=True, truncation=True)
+
+ input_ids = encoding['input_ids'].to(device)  # shape [num_train_examples, max_seq_len], e.g. [21025, 37]
+ attention_mask = encoding['attention_mask'].to(device)  # same shape as input_ids
+
+ # Get input_ids, attention_mask for val_seen, val_unseen
+ x_val_seen = val_set_seen['x_low']; y_val_seen = val_set_seen['y']
+ if args.no_appended:
+     x_val_seen = val_set_seen['x']
+ # Sample 100 fixed examples from each validation split (seeded for reproducibility)
+ random.seed(0)
+ sampled_val_seen = random.sample(range(len(x_val_seen)), 100)
+ x_val_seen = [x_val_seen[i] for i in sampled_val_seen]; y_val_seen = [y_val_seen[i] for i in sampled_val_seen]
+ labels_val_seen = torch.tensor(y_val_seen).to(device)
+ encoding_v_s = tokenizer(x_val_seen, return_tensors='pt', padding=True, truncation=True)
+ input_ids_val_seen = encoding_v_s['input_ids'].to(device)
+ attention_mask_val_seen = encoding_v_s['attention_mask'].to(device)
+
+ x_val_unseen = val_set_unseen['x_low']; y_val_unseen = val_set_unseen['y']
+ if args.no_appended:
+     x_val_unseen = val_set_unseen['x']
+ random.seed(1)
+ sampled_val_unseen = random.sample(range(len(x_val_unseen)), 100)
+ x_val_unseen = [x_val_unseen[i] for i in sampled_val_unseen]; y_val_unseen = [y_val_unseen[i] for i in sampled_val_unseen]
+ labels_val_unseen = torch.tensor(y_val_unseen).to(device)
+ encoding_v_us = tokenizer(x_val_unseen, return_tensors='pt', padding=True, truncation=True)
+ input_ids_val_unseen = encoding_v_us['input_ids'].to(device)
+ attention_mask_val_unseen = encoding_v_us['attention_mask'].to(device)
+
+ ##########
+ #### 2. Do training
+ #########
+
+ model.train()
+ from transformers import AdamW
+ optimizer = AdamW(model.parameters(), lr=args.learning_rate)
+ N = 16  # batch size
+ if args.no_appended:
+     N = 64
+ save_folder_name = args.model_type + '_lr_' + str(args.learning_rate) + '/'
+ if args.no_appended:
+     super_folder = 'saved_models_noappended_new_split_oct24/'
+ else:
+     super_folder = 'saved_models_appended_new_split_oct24/'
+ if not os.path.exists(super_folder + save_folder_name):
+     os.makedirs(super_folder + save_folder_name)
+
+ def accuracy(y_pred, y_batch):
+     # y_pred has shape [batch, no_classes]
+     maxed = torch.max(y_pred, 1)
+     y_hat = maxed.indices
+     num_accurate = torch.sum((y_hat == y_batch).float())
+     train_accuracy = num_accurate / y_hat.shape[0]
+     return train_accuracy.item()
+
+ def accurate_total(y_pred, y_batch):
+     # y_pred has shape [batch, no_classes]
+     maxed = torch.max(y_pred, 1)
+     y_hat = maxed.indices
+     num_accurate = torch.sum((y_hat == y_batch).float())
+     return num_accurate
+
+ start_train_time = time.time()
+ summary_writer = SummaryWriter(log_dir=super_folder + save_folder_name)
+ train_iter = 0
+ for t in trange(50):
+     model.train()
+     avg_training_loss = 0.0
+     training_acc = 0.0
+     for b in trange(int(input_ids.shape[0]/N)):
+         input_ids_batch = input_ids[N*b:N*(b+1)].to(device)
+         labels_batch = labels[N*b:N*(b+1)].to(device)
+         attention_mask_batch = attention_mask[N*b:N*(b+1)].to(device)
+         optimizer.zero_grad()
+         # forward pass
+         outputs = model(input_ids_batch, attention_mask=attention_mask_batch, labels=labels_batch.view(1,-1))
+         loss = outputs.loss
+         summary_writer.add_scalar('train/loss', loss, train_iter)
+
+         num_acc = accurate_total(y_pred=outputs.logits, y_batch=labels_batch.view(-1))
+         loss.backward()
+         train_iter += 1
+         optimizer.step()
+
+         avg_training_loss += loss
+         training_acc += num_acc
+     avg_training_loss *= 1/int(input_ids.shape[0]/N)
+     # Print & evaluate every epoch
+     if t % 1 == 0:
+         with torch.no_grad():
+             model.eval()
+             print("loss at step ", t, " : ", loss.item())
+             print("training accuracy: ", training_acc/input_ids.shape[0])
+             # evaluate
+             outputs_val_seen = model(input_ids_val_seen, attention_mask=attention_mask_val_seen)
+             print("validation accuracy (seen): ", accuracy(y_pred=outputs_val_seen.logits, y_batch=labels_val_seen.view(-1)))
+             del outputs_val_seen
+             outputs_val_unseen = model(input_ids_val_unseen, attention_mask=attention_mask_val_unseen)
+             print("validation accuracy (unseen): ", accuracy(y_pred=outputs_val_unseen.logits, y_batch=labels_val_unseen.view(-1)))
+
+     model_name = 'epoch_' + str(t) + '.pt'
+     torch.save(model.state_dict(), super_folder + save_folder_name + model_name)
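One way to launch this task-type classifier, using the argparse flags defined at the top of the file (the learning-rate value is illustrative, not prescribed by the script):

    python train_bert_base.py -lr 1e-5 -model_type bert-base

A checkpoint is saved every epoch under saved_models_appended_new_split_oct24/bert-base_lr_1e-05/ (or the noappended folder when --no_appended is passed); unlike train_bert_args.py above, this script keeps all 50 checkpoints rather than pruning to the best one.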
models/instructions_processed_LP/compare_BERT_pred_with_GT_oct24.py ADDED
@@ -0,0 +1,83 @@
+ import pickle
+
+ splits = ["valid_seen", "valid_unseen", "tests_seen", "tests_unseen"]
+ split = "valid_unseen"  # edit this to score a different split from `splits`
+
+ pred = pickle.load(open('instruction2_params_' + split + '_appended_new_split_oct24.p', 'rb'))
+ gt = pickle.load(open('instruction2_params_' + split + '_new_split_GT_oct24.p', 'rb'))
+
+ task_desc_list = list(pred.keys())
+ task_types = {"pick_cool_then_place_in_recep": 0, "pick_and_place_with_movable_recep": 1, "pick_and_place_simple": 2, "pick_two_obj_and_place": 3, "pick_heat_then_place_in_recep": 4, "look_at_obj_in_light": 5, "pick_clean_then_place_in_recep": 6}
+
+ # c0..c5 count tasks with exactly that many correct fields; nt/nm/no/np/ns count
+ # per-field matches (task_type, mrecep, object, parent, sliced).
+ c0 = 0; c1 = 0; c2 = 0; c3 = 0; c4 = 0; c5 = 0
+ nt = 0; nm = 0; no = 0; np = 0; ns = 0
+ parent_pred = []; parent_gt = []; parents = []
+ for i in range(len(pred)):
+     n = 0  # number of fields predicted correctly for this task
+     task = task_desc_list[i]
+     print("##########################################################################")
+     print(task)
+     if pred[task]['task_type'] == gt[task]['task_type']:
+         nt += 1; n += 1
+     else:
+         print('Task type: ', pred[task]['task_type'], gt[task]['task_type'])
+
+     if pred[task]['mrecep_target'] == gt[task]['mrecep_target']:
+         nm += 1; n += 1
+     else:
+         print('mrecep: ', pred[task]['mrecep_target'], gt[task]['mrecep_target'])
+
+     if pred[task]['object_target'] == gt[task]['object_target']:
+         no += 1; n += 1
+     else:
+         print('object_target: ', pred[task]['object_target'], gt[task]['object_target'])
+
+     if pred[task]['parent_target'] == gt[task]['parent_target']:
+         np += 1; n += 1
+     else:
+         print('parent_target: ', pred[task]['parent_target'], gt[task]['parent_target'])
+         parent_pred.append(pred[task]['parent_target'])
+         parent_gt.append(gt[task]['parent_target'])
+         parents.append([pred[task]['parent_target'], gt[task]['parent_target']])
+
+     if pred[task]['sliced'] == gt[task]['sliced']:
+         ns += 1; n += 1
+     else:
+         print('sliced: ', pred[task]['sliced'], gt[task]['sliced'])
+     # if pred[task]['toggle_target'] == gt[task]['toggle_target']:
+     #     nt += 1; n += 1
+     print("##########################################################################")
+
+     if n == 0:
+         c0 += 1
+     elif n == 1:
+         c1 += 1
+     elif n == 2:
+         c2 += 1
+     elif n == 3:
+         c3 += 1
+     elif n == 4:
+         c4 += 1
+     elif n == 5:
+         c5 += 1
+
+ print("################################################")
+ print(split)
+ print(f"Total number of task_desc: {len(pred), len(gt)}\n")
+ print(f"task_type accuracy: {round(nt/len(pred)*100, 2)}")
+ print(f"mrecep_target accuracy: {round(nm/len(pred)*100, 2)}")
+ print(f"object_target accuracy: {round(no/len(pred)*100, 2)}")
+ print(f"parent_target accuracy: {round(np/len(pred)*100, 2)}")
+ print(f"sliced accuracy: {round(ns/len(pred)*100, 2)}")
+ print()
+ print(f"0 Correct: {round(c0/len(pred)*100, 2)}")
+ print(f"1 Correct: {round(c1/len(pred)*100, 2)}")
+ print(f"2 Correct: {round(c2/len(pred)*100, 2)}")
+ print(f"3 Correct: {round(c3/len(pred)*100, 2)}")
+ print(f"4 Correct: {round(c4/len(pred)*100, 2)}")
+ print(f"5 Correct: {round(c5/len(pred)*100, 2)}")
+ print('\nTotal all correct accuracy: ', str(round(c5/len(pred)*100, 2)))
+ print("################################################")
models/instructions_processed_LP/instruction2_params_tests_seen_appended_new_split_oct24.p ADDED
Binary file (178 kB). View file
 
models/instructions_processed_LP/instruction2_params_tests_seen_new_split_GT_oct24.p ADDED
Binary file (244 kB). View file
 
models/instructions_processed_LP/instruction2_params_tests_unseen_appended_new_split_oct24.p ADDED
Binary file (173 kB). View file
 
models/instructions_processed_LP/instruction2_params_tests_unseen_new_split_GT_oct24.p ADDED
Binary file (233 kB). View file
 
models/instructions_processed_LP/instruction2_params_valid_seen_appended_new_split_oct24.p ADDED
Binary file (108 kB). View file
 
models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_oct24.p ADDED
Binary file (146 kB). View file
 
models/instructions_processed_LP/instruction2_params_valid_unseen_appended_new_split_oct24.p ADDED
Binary file (103 kB). View file
 
models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_oct24.p ADDED
Binary file (140 kB). View file