Upload 40 files
- .gitattributes +2 -0
- models/instructions_processed_LP/ALFRED_task_helper.py +296 -0
- models/instructions_processed_LP/BERT/best_models/base.pt +3 -0
- models/instructions_processed_LP/BERT/best_models/mrecep.pt +3 -0
- models/instructions_processed_LP/BERT/best_models/object.pt +3 -0
- models/instructions_processed_LP/BERT/best_models/parent.pt +3 -0
- models/instructions_processed_LP/BERT/best_models/sliced.pt +3 -0
- models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/correct_labels_dict_ppdl.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/correct_template_by_label_ppdl.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx_new_split.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx_new_split.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/template_by_label.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/toggle2idx.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/create_text_with_pddl_low_appended.py +112 -0
- models/instructions_processed_LP/BERT/data/alfred_data/create_text_with_pddl_low_appended_for_new_split.py +191 -0
- models/instructions_processed_LP/BERT/data/alfred_data/tests_seen_task_desc_to_task_id_oct24.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/tests_seen_text_with_ppdl_low_appended_new_split_oct24.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_task_desc_to_task_id_oct24.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_GT.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_new_split_oct24.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_test_unseen_new_split_GT.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/train_task_desc_to_task_id_oct24.p +3 -0
- models/instructions_processed_LP/BERT/data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p +3 -0
- models/instructions_processed_LP/BERT/data/alfred_data/valid_seen_task_desc_to_task_id_oct24.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split_oct24.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/valid_unseen_task_desc_to_task_id_oct24.p +0 -0
- models/instructions_processed_LP/BERT/data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split_oct24.p +0 -0
- models/instructions_processed_LP/BERT/end_to_end_outputs.py +210 -0
- models/instructions_processed_LP/BERT/train_bert_args.py +301 -0
- models/instructions_processed_LP/BERT/train_bert_base.py +179 -0
- models/instructions_processed_LP/compare_BERT_pred_with_GT_oct24.py +83 -0
- models/instructions_processed_LP/instruction2_params_tests_seen_appended_new_split_oct24.p +0 -0
- models/instructions_processed_LP/instruction2_params_tests_seen_new_split_GT_oct24.p +0 -0
- models/instructions_processed_LP/instruction2_params_tests_unseen_appended_new_split_oct24.p +0 -0
- models/instructions_processed_LP/instruction2_params_tests_unseen_new_split_GT_oct24.p +0 -0
- models/instructions_processed_LP/instruction2_params_valid_seen_appended_new_split_oct24.p +0 -0
- models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_oct24.p +0 -0
- models/instructions_processed_LP/instruction2_params_valid_unseen_appended_new_split_oct24.p +0 -0
- models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_oct24.p +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/instructions_processed_LP/BERT/data/alfred_data/train_task_desc_to_task_id_oct24.p filter=lfs diff=lfs merge=lfs -text
+models/instructions_processed_LP/BERT/data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p filter=lfs diff=lfs merge=lfs -text

models/instructions_processed_LP/ALFRED_task_helper.py
ADDED
@@ -0,0 +1,296 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 27 16:49:38 2021

@author: soyeonmin
"""
import pickle
import string

import alfred_utils.gen.constants as constants

exclude = set(string.punctuation)
task_type_dict = {2: 'pick_and_place_simple',
                  5: 'look_at_obj_in_light',
                  1: 'pick_and_place_with_movable_recep',
                  3: 'pick_two_obj_and_place',
                  6: 'pick_clean_then_place_in_recep',
                  4: 'pick_heat_then_place_in_recep',
                  0: 'pick_cool_then_place_in_recep'}


def read_test_dict(test, appended, unseen):
    if test:
        if appended:
            if unseen:
                return pickle.load(open("models/instructions_processed_LP/instruction2_params_tests_unseen_new_split_GT_aug28.p", "rb"))
            else:
                return pickle.load(open("models/instructions_processed_LP/instruction2_params_tests_seen_new_split_GT_aug28.p", "rb"))
        else:
            if unseen:
                # REALFRED
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_test_unseen_916_noappended.p", "rb"))
                return pickle.load(open("models/instructions_processed_LP/instruction2_params_test_unseen_noappended_new_split_bert_trained.p", "rb"))
            else:
                # REALFRED
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_test_seen_916_noappended.p", "rb"))
                return pickle.load(open("models/instructions_processed_LP/instruction2_params_test_seen_noappended_new_split_bert_trained.p", "rb"))
    else:
        if appended:
            if unseen:
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_val_unseen_appended.p", "rb"))
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_aug28.p", "rb"))
                # tests_unseen
                return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_aug28_t.p", "rb"))
            else:
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_val_seen_appended.p", "rb"))
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_aug28.p", "rb"))
                # tests_seen
                return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_aug28_t.p", "rb"))
        else:
            if unseen:
                # REALFRED
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_val_unseen_916_noappended.p", "rb"))
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_unseen_noappended_new_split.p", "rb"))
                return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_aug28_t.p", "rb"))
            else:
                # return pickle.load(open("models/instructions_processed_LP/instruction2_params_val_seen_916_noappended.p", "rb"))
                return pickle.load(open("models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_aug28_t.p", "rb"))


def exist_or_no(value):
    # Parameter renamed from `string` to avoid shadowing the string module.
    if value == '' or value == False:
        return 0
    else:
        return 1


def none_or_str(value):
    if value == '':
        return None
    else:
        return value


def get_arguments_test(test_dict, instruction):
    task_type, mrecep_target, object_target, parent_target, sliced = \
        test_dict[instruction]['task_type'], test_dict[instruction]['mrecep_target'], \
        test_dict[instruction]['object_target'], test_dict[instruction]['parent_target'], \
        test_dict[instruction]['sliced']

    if isinstance(task_type, int):
        task_type = task_type_dict[task_type]
    return instruction, task_type, mrecep_target, object_target, parent_target, sliced


def get_arguments(traj_data):
    task_type = traj_data['task_type']
    try:
        # r_idx = traj_data['repeat_idx']
        r_idx = traj_data['ann']['repeat_idx']
    except KeyError:
        r_idx = 0
    language_goal_instr = traj_data['turk_annotations']['anns'][r_idx]['task_desc']

    sliced = exist_or_no(traj_data['pddl_params']['object_sliced'])
    mrecep_target = none_or_str(traj_data['pddl_params']['mrecep_target'])
    object_target = none_or_str(traj_data['pddl_params']['object_target'])
    parent_target = none_or_str(traj_data['pddl_params']['parent_target'])
    # toggle_target = none_or_str(traj_data['pddl_params']['toggle_target'])

    return language_goal_instr, task_type, mrecep_target, object_target, parent_target, sliced


def add_target(target, target_action, list_of_actions):
    openable = [a for a in constants.OPENABLE_CLASS_LIST if not a == 'Box']
    if target in openable:
        list_of_actions.append((target, "OpenObject"))
    list_of_actions.append((target, target_action))
    if target in openable:
        list_of_actions.append((target, "CloseObject"))
    return list_of_actions


def determine_consecutive_interx(list_of_actions, previous_pointer, sliced=False):
    returned, target_instance = False, None
    if previous_pointer + 1 <= len(list_of_actions) - 1:  # both pointer and pointer+1 must be valid indices
        if list_of_actions[previous_pointer][0] == list_of_actions[previous_pointer + 1][0]:
            returned = True
            # target_instance = list_of_target_instance[-1]  # previous target
            target_instance = list_of_actions[previous_pointer][0]
        # Microwave or Fridge
        elif list_of_actions[previous_pointer][1] == "OpenObject" and list_of_actions[previous_pointer + 1][1] == "PickupObject":
            returned = True
            # target_instance = list_of_target_instance[0]
            target_instance = list_of_actions[0][0]
            if sliced:
                # target_instance = list_of_target_instance[3]
                target_instance = list_of_actions[3][0]
        # Microwave or Fridge
        elif list_of_actions[previous_pointer][1] == "PickupObject" and list_of_actions[previous_pointer + 1][1] == "CloseObject":
            returned = True
            # target_instance = list_of_target_instance[-2]  # e.g. Fridge
            target_instance = list_of_actions[previous_pointer - 1][0]
        # Faucet
        elif list_of_actions[previous_pointer + 1][0] == "Faucet" and list_of_actions[previous_pointer + 1][1] in ["ToggleObjectOn", "ToggleObjectOff"]:
            returned = True
            target_instance = "Faucet"
        # Pick up after faucet
        elif list_of_actions[previous_pointer][0] == "Faucet" and list_of_actions[previous_pointer + 1][1] == "PickupObject":
            returned = True
            # target_instance = list_of_target_instance[0]
            target_instance = list_of_actions[0][0]
            if sliced:
                # target_instance = list_of_target_instance[3]
                target_instance = list_of_actions[3][0]
    return returned, target_instance


def get_list_of_highlevel_actions(traj_data, test=False, test_dict=None, args_nonsliced=False, appended=False):
    if not test:
        language_goal, task_type, mrecep_target, obj_target, parent_target, sliced = get_arguments(traj_data)
    else:
        r_idx = traj_data['ann']['repeat_idx']
        instruction = traj_data['turk_annotations']['anns'][r_idx]['task_desc']
        # aug28
        # if appended:
        instruction = instruction.lower()
        instruction = ''.join(ch for ch in instruction if ch not in exclude)
        language_goal, task_type, mrecep_target, obj_target, parent_target, sliced = get_arguments_test(test_dict, instruction)

    # obj_target = 'Tomato'
    # mrecep_target = "Plate"
    if parent_target == "Sink":
        parent_target = "SinkBasin"
    if parent_target == "Bathtub":
        parent_target = "BathtubBasin"

    # Change obj_target to the sliced object after the slicing happens
    if args_nonsliced:
        if sliced == 1:
            obj_target = obj_target + 'Sliced'
        # The sliced object maps to the same place in the map; like "|SinkBasin", look at the object id

    categories_in_inst = []
    list_of_highlevel_actions = []
    second_object = []
    caution_pointers = []
    # obj_target = "Tomato"

    if sliced == 1:
        list_of_highlevel_actions.append(("Knife", "PickupObject"))
        list_of_highlevel_actions.append((obj_target, "SliceObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions.append(("SinkBasin", "PutObject"))
        categories_in_inst.append(obj_target)

    if sliced:
        obj_target = obj_target + 'Sliced'

    if task_type == 'pick_cool_then_place_in_recep':  # 0 in new_labels
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target("Fridge", "PutObject", list_of_highlevel_actions)
        list_of_highlevel_actions.append(("Fridge", "OpenObject"))
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        list_of_highlevel_actions.append(("Fridge", "CloseObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
        categories_in_inst.append(obj_target)
        categories_in_inst.append("Fridge")
        categories_in_inst.append(parent_target)

    elif task_type == 'pick_and_place_with_movable_recep':  # 1 in new_labels
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target(mrecep_target, "PutObject", list_of_highlevel_actions)
        list_of_highlevel_actions.append((mrecep_target, "PickupObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
        categories_in_inst.append(obj_target)
        categories_in_inst.append(mrecep_target)
        categories_in_inst.append(parent_target)

    elif task_type == 'pick_and_place_simple':  # 2 in new_labels
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
        # list_of_highlevel_actions.append((parent_target, "PutObject"))
        categories_in_inst.append(obj_target)
        categories_in_inst.append(parent_target)

    elif task_type == 'pick_heat_then_place_in_recep':  # 4 in new_labels
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target("Microwave", "PutObject", list_of_highlevel_actions)
        list_of_highlevel_actions.append(("Microwave", "ToggleObjectOn"))
        list_of_highlevel_actions.append(("Microwave", "ToggleObjectOff"))
        list_of_highlevel_actions.append(("Microwave", "OpenObject"))
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        list_of_highlevel_actions.append(("Microwave", "CloseObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
        categories_in_inst.append(obj_target)
        categories_in_inst.append("Microwave")
        categories_in_inst.append(parent_target)

    elif task_type == 'pick_two_obj_and_place':  # 3 in new_labels
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
        if parent_target in constants.OPENABLE_CLASS_LIST:
            second_object = [False] * 4
        else:
            second_object = [False] * 2
        if sliced:
            second_object = second_object + [False] * 3
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        # caution_pointers.append(len(list_of_highlevel_actions))
        second_object.append(True)
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
        second_object.append(False)
        categories_in_inst.append(obj_target)
        categories_in_inst.append(parent_target)

    elif task_type == 'look_at_obj_in_light':  # 5 in new_labels
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        # if toggle_target == "DeskLamp":
        #     print("Original toggle target was DeskLamp")
        toggle_target = "FloorLamp"
        list_of_highlevel_actions.append((toggle_target, "ToggleObjectOn"))
        categories_in_inst.append(obj_target)
        categories_in_inst.append(toggle_target)

    elif task_type == 'pick_clean_then_place_in_recep':  # 6 in new_labels
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions.append(("SinkBasin", "PutObject"))  # Sink or SinkBasin?
        list_of_highlevel_actions.append(("Faucet", "ToggleObjectOn"))
        list_of_highlevel_actions.append(("Faucet", "ToggleObjectOff"))
        list_of_highlevel_actions.append((obj_target, "PickupObject"))
        caution_pointers.append(len(list_of_highlevel_actions))
        list_of_highlevel_actions = add_target(parent_target, "PutObject", list_of_highlevel_actions)
        categories_in_inst.append(obj_target)
        categories_in_inst.append("SinkBasin")
        categories_in_inst.append("Faucet")
        categories_in_inst.append(parent_target)
    else:
        raise Exception("Task type not one of 0, 1, 2, 3, 4, 5, 6!")

    if sliced == 1:
        if not parent_target == "SinkBasin":
            categories_in_inst.append("SinkBasin")

    # returns [(goal_category, interaction), (goal_category, interaction), ...]
    print("instruction goal is ", language_goal)
    # list_of_highlevel_actions = [('Microwave', 'OpenObject'), ('Microwave', 'PutObject'), ('Microwave', 'CloseObject')]
    # list_of_highlevel_actions = [('Microwave', 'OpenObject'), ('Microwave', 'PutObject'), ('Microwave', 'CloseObject'), ('Microwave', 'ToggleObjectOn'), ('Microwave', 'ToggleObjectOff'), ('Microwave', 'OpenObject'), ('Apple', 'PickupObject'), ('Microwave', 'CloseObject'), ('Fridge', 'OpenObject'), ('Fridge', 'PutObject'), ('Fridge', 'CloseObject')]
    # categories_in_inst = ['Microwave', 'Fridge']
    return list_of_highlevel_actions, categories_in_inst, second_object, caution_pointers

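For reference, and not part of the commit: a minimal standalone sketch of the (target, interaction) tuples that get_list_of_highlevel_actions returns, traced from the 'pick_clean_then_place_in_recep' branch above. The object and receptacle names ('Apple', 'CounterTop') are assumed example values; 'CounterTop' is treated as non-openable, so add_target emits no Open/Close pair.

# Illustrative only: the plan the downstream planner consumes for
# obj_target='Apple', parent_target='CounterTop', sliced=0.
expected_plan = [
    ("Apple", "PickupObject"),
    ("SinkBasin", "PutObject"),
    ("Faucet", "ToggleObjectOn"),
    ("Faucet", "ToggleObjectOff"),
    ("Apple", "PickupObject"),
    ("CounterTop", "PutObject"),
]
for step, (target, interaction) in enumerate(expected_plan):
    print(step, target, interaction)
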
models/instructions_processed_LP/BERT/best_models/base.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2099e73c1ebdf22a9ae9ac122beaaf29da9e2cccf4eabf659919a20f42e9108b
size 438040777
models/instructions_processed_LP/BERT/best_models/mrecep.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d47ad0207d1672ff9a12af409a7681eb16be8faf7af7465c2e40804deae2fee2
size 438419145
models/instructions_processed_LP/BERT/best_models/object.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:56b153e583afa38ac5a7369d72f67a99092678e29472995cf62c4fa48fefe0c3
size 438419145
models/instructions_processed_LP/BERT/best_models/parent.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c0b683d812d0fb0adaa6a92107c586e3dfb22bf20a43836df8b34a0ca317671
size 438148425
models/instructions_processed_LP/BERT/best_models/sliced.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:962b51b86dd1dcca9611ca0a1f6bc0b6a44d18e8de94cb07406459b5333ec749
size 438025417
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/correct_labels_dict_ppdl.p
ADDED
Binary file (235 Bytes).
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/correct_template_by_label_ppdl.p
ADDED
Binary file (114 Bytes).
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx.p
ADDED
Binary file (1.66 kB).
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx_new_split.p
ADDED
Binary file (2.15 kB).
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx.p
ADDED
Binary file (552 Bytes).
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx_new_split.p
ADDED
Binary file (693 Bytes).
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/template_by_label.p
ADDED
Binary file (114 Bytes).
models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/toggle2idx.p
ADDED
Binary file (213 Bytes).
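The .pt entries above are Git LFS pointer files, not the checkpoints themselves. A minimal sketch, assuming only the three-line pointer format shown above, of parsing one into a dict (parse_lfs_pointer is a hypothetical helper, not part of the commit):

def parse_lfs_pointer(text: str) -> dict:
    # Each pointer line is "key value"; split on the first space.
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(' ')
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:2099e73c1ebdf22a9ae9ac122beaaf29da9e2cccf4eabf659919a20f42e9108b
size 438040777"""
print(parse_lfs_pointer(pointer)['size'])  # 438040777
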
models/instructions_processed_LP/BERT/data/alfred_data/create_text_with_pddl_low_appended.py
ADDED
@@ -0,0 +1,112 @@
import os
import pandas as pd
import argparse
import json
import pickle
import string

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default="/media/user/data/FILM/alfred_data_all/json_new", help="where to look for the generated data")
parser.add_argument('--split', type=str, default="tests_unseen")
parser.add_argument('-o', '--output_name', type=str)
args = parser.parse_args()
data_path = args.data_path
split = args.split
exclude = set(string.punctuation)
result = dict()

# Path of the 'traj_data.json' files that were created when the trajectory was generated
# traj_data_path = "/media/user/data/alfred_4.3.0/gen/dataset/Finished"
traj_data_path = "/media/user/data/FILM/alfred_data_all/json_2.1.0/valid_unseen"
task_to_path = dict()
desc_to_gt_params = dict()  # normalized task_desc (e.g. 'wash the brown vegetable and put it on the counter') -> GT PDDL params
JSON_FILENAME = "traj_data.json"
for dir_name, _, _ in os.walk(traj_data_path):
    if "trial_" in dir_name and (not "raw_images" in dir_name) and (not "pddl_states" in dir_name) and (not "video" in dir_name):
        json_file = os.path.join(dir_name, JSON_FILENAME)
        if not os.path.isfile(json_file):
            continue
        task = json_file.split('/')[-2]
        task_to_path[task] = json_file
# print(task_to_path)

task_types = {"pick_cool_then_place_in_recep": 0, "pick_and_place_with_movable_recep": 1, "pick_and_place_simple": 2, "pick_two_obj_and_place": 3, "pick_heat_then_place_in_recep": 4, "look_at_obj_in_light": 5, "pick_clean_then_place_in_recep": 6}

# x_task maps 'task_desc' to task id, e.g.
# {'wash the brown vegetable and put it on the counter': 'trial_T20230514_192232_882070', ...}
x_task = dict()

result['x'] = []; result['x_low'] = []
n = 0  # number of annotations
d = 0  # number of duplicated task descriptions
# with open('../../../../../alfred_data_small/splits/REALFRED_splits.json', 'r') as f:
# with open('alfred_data_small/splits/REALFRED_splits.json', 'r') as f:
with open('../../../../../alfred_data_small/splits/oct21.json', 'r') as f:
    splits = json.load(f)
tasks = list()
for i in splits[split]:
    task = i["task"]
    if task not in tasks:
        tasks.append(task)

for task in tasks:
    with open(os.path.join(data_path, task, 'pp', 'ann_0.json')) as f:
        ann_0 = json.load(f)

    anns = ann_0['turk_annotations']['anns']  # anns = [{"assignment_id", "high_descs", "task_desc"}, {}, {}]
    for j in anns:
        task_desc = j['task_desc']
        if task_desc[-1] == '.':
            task_desc = task_desc[:-1]
        task_desc = task_desc.lower()
        task_desc = ''.join(ch for ch in task_desc if ch not in exclude)
        if task_desc in result['x']:
            print(task, task_desc)
            d += 1
        result['x'].append(task_desc)

        x_low = ''
        for k in j['high_descs']:
            if k[-1] == '.':
                k = k[:-1]
            k = k.lower()
            k = ''.join(ch for ch in k if ch not in exclude)
            x_low = x_low + k + '[SEP]'
        x_low = x_low[:-5]  # strip the trailing '[SEP]' (5 chars; the original [:-6] also cut the last character)
        result['x_low'].append(x_low)
        n += 1

        # x_task
        x_task[task_desc] = task
        path = task_to_path[task]
        with open(path) as f:
            traj_data = json.load(f)
        params = dict()
        task_type = task_types[traj_data['task_type']]
        params['task_type'] = task_type
        if traj_data['pddl_params']['mrecep_target'] != "":
            params['mrecep_target'] = traj_data['pddl_params']['mrecep_target']
        else:
            params['mrecep_target'] = None
        params['object_target'] = traj_data['pddl_params']['object_target']
        if traj_data['pddl_params']['parent_target'] != "":
            params['parent_target'] = traj_data['pddl_params']['parent_target']
        else:
            params['parent_target'] = None
        if traj_data['pddl_params']['object_sliced']:
            params['sliced'] = 1
        else:
            params['sliced'] = 0

        desc_to_gt_params[task_desc] = params


# pickle.dump(result, open(args.output_name + ".p", "wb"))
pickle.dump(result, open(args.split + '_text_with_ppdl_low_appended.p', 'wb'))
# pickle.dump(x_task, open(args.split + '_task_desc_to_task_id.p', 'wb'))
pickle.dump(desc_to_gt_params, open('../../../instruction2_params_test_unseen_noappended_GT.p', 'wb'))
print(f"num of annotations: {n}")
print(f"num of duplication: {d}")
print(f"num of not duplication: {n-d}")

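A minimal standalone sketch of the text normalization this script applies (the same normalization read_test_dict consumers rely on when looking instructions up by key): drop a trailing period, lowercase, strip all punctuation, and join step instructions with '[SEP]'. The sample sentences are assumed example values; joining with '[SEP]'.join sidesteps the manual-trim off-by-one fixed above.

import string

exclude = set(string.punctuation)

def normalize(desc: str) -> str:
    # Drop a trailing period, lowercase, and strip punctuation so the
    # string can serve as a dictionary key at inference time.
    if desc and desc[-1] == '.':
        desc = desc[:-1]
    return ''.join(ch for ch in desc.lower() if ch not in exclude)

high_descs = ["Walk to the sink.", "Wash the apple.", "Put it on the counter."]
x_low = '[SEP]'.join(normalize(k) for k in high_descs)
print(x_low)  # walk to the sink[SEP]wash the apple[SEP]put it on the counter
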
models/instructions_processed_LP/BERT/data/alfred_data/create_text_with_pddl_low_appended_for_new_split.py
ADDED
@@ -0,0 +1,191 @@
import os
import pandas as pd
import argparse
import json
import pickle
import string

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default="/media/user/data2/FILM/alfred_data_all/json_2.1.0", help="where to look for the generated data")  # annotation path
parser.add_argument('--split', type=str, default="valid_unseen")
# parser.add_argument('-o', '--output_name', type=str)
args = parser.parse_args()
data_path = args.data_path
split = args.split
exclude = set(string.punctuation)
result = dict()

# Path of the 'traj_data.json' files that were created when the trajectory was generated
# traj_data_path = "/media/user/data/alfred_4.3.0/gen/dataset/Finished"
traj_data_path = "/media/user/data2/FILM/alfred_data_all/Re_json_2.1.0"
task_to_path = dict()
desc_to_gt_params = dict()  # normalized task_desc -> GT PDDL params
JSON_FILENAME = "traj_data.json"
for dir_name, _, _ in os.walk(traj_data_path):
    if "trial_" in dir_name and (not "raw_images" in dir_name) and (not "pddl_states" in dir_name) and (not "video" in dir_name):
        json_file = os.path.join(dir_name, JSON_FILENAME)
        if not os.path.isfile(json_file):
            continue
        task = json_file.split('/')[-2]
        task_to_path[task] = json_file
# print(task_to_path)

task_types = {"pick_cool_then_place_in_recep": 0, "pick_and_place_with_movable_recep": 1, "pick_and_place_simple": 2, "pick_two_obj_and_place": 3, "pick_heat_then_place_in_recep": 4, "look_at_obj_in_light": 5, "pick_clean_then_place_in_recep": 6}

# x_task maps 'task_desc' to task id, e.g.
# {'wash the brown vegetable and put it on the counter': 'trial_T20230514_192232_882070', ...}
x_task = dict()

if args.split in ['tests_seen', 'tests_unseen']:
    result['x'] = []; result['x_low'] = []
if args.split in ['train', 'valid_seen', 'valid_unseen']:
    result['x'] = []; result['y'] = []; result['s'] = []; result['mrecep_targets'] = []; result['object_targets'] = []; result['parent_targets'] = []; result['toggle_targets'] = []; result['x_low'] = []

n = 0  # number of annotations
d = 0  # number of duplicated task descriptions
# with open('../../../../../alfred_data_small/splits/REALFRED_splits.json', 'r') as f:
# with open('alfred_data_small/splits/REALFRED_splits.json', 'r') as f:
# with open('../../../../../alfred_data_small/splits/aug28.json', 'r') as f:
# with open('../../../../../alfred_data_small/splits/oct24.json', 'r') as f:
with open('/media/user/data2/FILM/alfred_data_small/splits/oct24.json', 'r') as f:
    splits = json.load(f)
tasks = list()
for i in splits[split]:
    task = i["task"]
    print(task)
    if task not in tasks:
        tasks.append(task)
print("\n\n")
for task in tasks:
    with open(os.path.join(traj_data_path, task, 'pp', 'ann_0.json')) as f:
        ann_0 = json.load(f)

    # anns = [{"assignment_id", "high_descs", "task_desc"}, {}, {}]

    ##### Without annotation #####
    # anns = [ann_0['template']]

    # 0514
    anns = ann_0['turk_annotations']['anns']

    for j in anns:
        task_desc = j['task_desc']
        if len(task_desc) > 0:
            if task_desc[-1] == '.':
                task_desc = task_desc[:-1]
        task_desc = task_desc.lower()
        task_desc = ''.join(ch for ch in task_desc if ch not in exclude)
        if task_desc in result['x']:
            print(task, task_desc)
            print("====")
            d += 1
        result['x'].append(task_desc)

        x_low = ''

        ##### Without template #####
        # for k in j['high_descs']:

        ##### Without annotation #####
        # for k in j['high_descs'][:-1]:
        for k in j['high_descs']:
            if len(k) > 0:
                if k[-1] == '.':
                    k = k[:-1]
            k = k.lower()
            k = ''.join(ch for ch in k if ch not in exclude)
            x_low = x_low + k + '[SEP]'
        x_low = x_low[:-5]  # strip the trailing '[SEP]'
        result['x_low'].append(x_low)
        n += 1

        # Extract the GT params from traj_data.json
        # x_task
        x_task[task_desc] = task

        task_trial = task[-29:]
        path = task_to_path[task_trial]
        with open(path) as f:
            traj_data = json.load(f)

        params = dict()
        task_type = task_types[traj_data['task_type']]
        params['task_type'] = task_type
        if traj_data['pddl_params']['mrecep_target'] != "":
            params['mrecep_target'] = traj_data['pddl_params']['mrecep_target']
        else:
            params['mrecep_target'] = None
        params['object_target'] = traj_data['pddl_params']['object_target']
        if traj_data['pddl_params']['parent_target'] != "":
            params['parent_target'] = traj_data['pddl_params']['parent_target']
        else:
            params['parent_target'] = None
        if traj_data['pddl_params']['object_sliced']:
            params['sliced'] = 1
        else:
            params['sliced'] = 0
        if traj_data['pddl_params']['toggle_target'] != "":
            params['toggle_target'] = traj_data['pddl_params']['toggle_target']
        else:
            params['toggle_target'] = None

        desc_to_gt_params[task_desc] = params

        if args.split in ['train', 'valid_seen', 'valid_unseen']:
            # obj2idx_new_split, recep2idx_new_split (reloaded per annotation as in the original; could be hoisted out of the loop)
            # obj2idx = pickle.load(open('alfred_dicts/obj2idx_new_split.p', 'rb'))
            # recep2idx = pickle.load(open('alfred_dicts/recep2idx_new_split.p', 'rb'))
            # toggle2idx = pickle.load(open('alfred_dicts/toggle2idx.p', 'rb'))
            obj2idx = pickle.load(open('models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/obj2idx_new_split.p', 'rb'))
            recep2idx = pickle.load(open('models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/recep2idx_new_split.p', 'rb'))
            toggle2idx = pickle.load(open('models/instructions_processed_LP/BERT/data/alfred_data/alfred_dicts/toggle2idx.p', 'rb'))

            result['y'].append(task_type)

            if traj_data['pddl_params']['object_sliced']:
                result['s'].append(1)
            else:
                result['s'].append(0)

            if traj_data['pddl_params']['mrecep_target'] != "":
                result['mrecep_targets'].append(obj2idx[traj_data['pddl_params']['mrecep_target']])
            else:
                result['mrecep_targets'].append(obj2idx[None])

            result['object_targets'].append(obj2idx[traj_data['pddl_params']['object_target']])

            if traj_data['pddl_params']['parent_target'] != "":
                result['parent_targets'].append(recep2idx[traj_data['pddl_params']['parent_target']])
            else:
                result['parent_targets'].append(recep2idx[None])

            if traj_data['pddl_params']['toggle_target'] != "":
                result['toggle_targets'].append(toggle2idx[traj_data['pddl_params']['toggle_target']])
            else:
                result['toggle_targets'].append(toggle2idx[None])


# pickle.dump(result, open(args.output_name + ".p", "wb"))
# pickle.dump(result, open(args.split + '_text_with_ppdl_low_appended_new_split_aug28_t.p', 'wb'))
# pickle.dump(x_task, open(args.split + '_task_desc_to_task_id_aug28_t.p', 'wb'))
# pickle.dump(desc_to_gt_params, open('../../../instruction2_params_' + args.split + '_new_split_GT_aug28_t.p', 'wb'))

# 0514
# pickle.dump(result, open(args.split + '_text_with_ppdl_low_appended_new_split_oct24.p', 'wb'))
# pickle.dump(x_task, open(args.split + '_task_desc_to_task_id_oct24.p', 'wb'))
# pickle.dump(desc_to_gt_params, open('../../../instruction2_params_' + args.split + '_new_split_GT_oct24.p', 'wb'))
pickle.dump(result, open('models/instructions_processed_LP/BERT/data/alfred_data/' + args.split + '_text_with_ppdl_low_appended_new_split_oct24.p', 'wb'))
pickle.dump(x_task, open('models/instructions_processed_LP/BERT/data/alfred_data/' + args.split + '_task_desc_to_task_id_oct24.p', 'wb'))
pickle.dump(desc_to_gt_params, open('models/instructions_processed_LP/instruction2_params_' + args.split + '_new_split_GT_oct24.p', 'wb'))
print(f"num of annotations: {n}")
print(f"num of duplication: {d}")
print(f"num of not duplication: {n-d}")

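For reference, and not part of the commit: the record layout of the pickled training dict this script writes for train/valid splits, reconstructed from the append calls above. The index values depend on the obj2idx/recep2idx/toggle2idx pickles, so the numbers here are placeholders; all lists are aligned per annotation.

example = {
    'x': ['wash the apple and put it on the counter'],                 # normalized goal descriptions
    'x_low': ['walk to the sink[SEP]wash the apple[SEP]put it on the counter'],  # [SEP]-joined step instructions
    'y': [6],               # task-type label (pick_clean_then_place_in_recep)
    's': [0],               # object_sliced flag
    'mrecep_targets': [0],  # placeholder for obj2idx[None]
    'object_targets': [12], # placeholder for obj2idx['Apple']
    'parent_targets': [3],  # placeholder for recep2idx['CounterTop']
    'toggle_targets': [0],  # placeholder for toggle2idx[None]
}
assert all(len(v) == 1 for v in example.values())  # one entry per annotation
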
models/instructions_processed_LP/BERT/data/alfred_data/tests_seen_task_desc_to_task_id_oct24.p
ADDED
Binary file (155 kB).
models/instructions_processed_LP/BERT/data/alfred_data/tests_seen_text_with_ppdl_low_appended_new_split_oct24.p
ADDED
Binary file (767 kB).
models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_task_desc_to_task_id_oct24.p
ADDED
Binary file (149 kB).
models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_GT.p
ADDED
Binary file (61.9 kB).
models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_new_split_oct24.p
ADDED
Binary file (756 kB).
models/instructions_processed_LP/BERT/data/alfred_data/tests_unseen_text_with_ppdl_low_appended_test_unseen_new_split_GT.p
ADDED
Binary file (167 kB).
models/instructions_processed_LP/BERT/data/alfred_data/train_task_desc_to_task_id_oct24.p
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0d48d00106a760795caef5f1dbbe76162925fd450a43c2bfeb406453974b26f3
size 2003819
models/instructions_processed_LP/BERT/data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33039ca77d04ab3e9806d7de5e687620144dbcffc42efcead167c77054e69307
size 9665585
models/instructions_processed_LP/BERT/data/alfred_data/valid_seen_task_desc_to_task_id_oct24.p
ADDED
Binary file (114 kB).
models/instructions_processed_LP/BERT/data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split_oct24.p
ADDED
Binary file (480 kB).
models/instructions_processed_LP/BERT/data/alfred_data/valid_unseen_task_desc_to_task_id_oct24.p
ADDED
Binary file (111 kB).
models/instructions_processed_LP/BERT/data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split_oct24.p
ADDED
Binary file (461 kB).
models/instructions_processed_LP/BERT/end_to_end_outputs.py
ADDED
@@ -0,0 +1,210 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Feb 3 23:52:39 2021

@author: soyeonmin
"""

# Run the base model (into templates) and then extract arguments

import random
import time
import torch
from torch import nn
import pickle
import glob
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('-sp', '--split', type=str, choices=['valid_unseen', 'valid_seen', 'tests_seen', 'tests_unseen'], required=True)
parser.add_argument('-m', '--model_saved_folder_name', type=str, required=True, default='best_models')
parser.add_argument('-o', '--output_name', type=str, required=True)
parser.add_argument('--no_appended', action='store_true')

args = parser.parse_args()


def accuracy(y_pred, y_batch):
    # y_pred has shape [batch, no_classes]
    maxed = torch.max(y_pred, 1)
    y_hat = maxed.indices
    num_accurate = torch.sum((y_hat == y_batch).long())
    train_accuracy = num_accurate / y_hat.shape[0]
    return train_accuracy.item()


def accurate_both(y_pred1, y_batch1, y_pred2, y_batch2):
    maxed1 = torch.max(y_pred1, 1)
    y_hat1 = maxed1.indices
    maxed2 = torch.max(y_pred2, 1)
    y_hat2 = maxed2.indices
    num_both_accurate = torch.sum((y_hat1 == y_batch1).long() * (y_hat2 == y_batch2).long())
    train_accuracy = num_both_accurate / y_hat1.shape[0]
    return train_accuracy.item()


# Load data
val_set_unseen = pickle.load(open('data/alfred_data/' + args.split + '_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
desc2id = pickle.load(open('data/alfred_data/' + args.split + '_task_desc_to_task_id_oct24.p', 'rb'))

# REALFRED new split
obj2idx = pickle.load(open('data/alfred_data/alfred_dicts/obj2idx_new_split.p', 'rb'))
recep2idx = pickle.load(open('data/alfred_data/alfred_dicts/recep2idx_new_split.p', 'rb'))

toggle2idx = pickle.load(open('data/alfred_data/alfred_dicts/toggle2idx.p', 'rb'))

idx2obj = {v: k for k, v in obj2idx.items()}
idx2recep = {v: k for k, v in recep2idx.items()}
idx2toggle = {v: k for k, v in toggle2idx.items()}

# These are based on new labels: which task-type labels each argument head applies to
task_to_label_mapping = {'mrecep_target': [1], 'object_target': [0, 1, 2, 3, 4, 5, 6],
                         'parent_target': [0, 1, 2, 3, 4, 6], 'toggle_target': [5],
                         'sliced': [0, 1, 2, 3, 4, 5, 6]}

# Set device
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


save_folder_name = args.model_saved_folder_name
# Base model
from transformers import BertForSequenceClassification
base_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=7).to(device)

base_model_name = 'base.pt'
base_model.load_state_dict(torch.load(os.path.join(save_folder_name, base_model_name)))
base_model.eval()

# Run data through the base model and get the template label
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Put validation data into the base_model
x_val_seen = val_set_unseen['x_low']  # ; y_val_seen = val_set_seen['y']
if args.no_appended:
    x_val_seen = val_set_unseen['x']


encoding_v_s = tokenizer(x_val_seen, return_tensors='pt', padding=True, truncation=True)
input_ids_val_seen = encoding_v_s['input_ids'].to(device)
attention_mask_val_seen = encoding_v_s['attention_mask'].to(device)

N = 10
y_hat_list_vs = []
if input_ids_val_seen.shape[0] % N != 0:
    until = int(input_ids_val_seen.shape[0] / N) + 1
else:
    until = int(input_ids_val_seen.shape[0] / N)
for b in range(until):
    input_ids_batch = input_ids_val_seen[N * b:N * (b + 1)].to(device)
    attention_mask_batch = attention_mask_val_seen[N * b:N * (b + 1)].to(device)
    outputs = base_model(input_ids_batch, attention_mask=attention_mask_batch)
    predicted_templates = torch.max(outputs.logits, 1).indices
    y_hat_list_vs += predicted_templates.cpu().numpy().tolist()

del outputs
del base_model
vs_idx2predicted_label = {i: y for i, y in enumerate(y_hat_list_vs)}

# Now extract the arguments
c = 0


def get_prediction(classifier, N, input_ids, attention_mask):
    y_hat_list = []
    for b in range(int(input_ids.shape[0] / N) + 1):
        if b != int(input_ids.shape[0] / N):
            input_ids_batch = input_ids[N * b:N * (b + 1)].to(device)
        else:
            input_ids_batch = input_ids[N * b:].to(device)
        if input_ids_batch.shape[0] == 0:
            break  # last batch is empty when the length is divisible by N
        attention_mask_batch = attention_mask[N * b:N * (b + 1)].to(device)

        # outputs = base_model(input_ids_batch, attention_mask=attention_mask_batch, labels=labels_batch.view(1,-1))
        # print('input_ids_batch: ', input_ids_batch.shape)
        outputs = classifier(input_ids_batch, attention_mask=attention_mask_batch)
        global c
        c += 1
        try:
            print(c, torch.max(outputs.logits, 1).indices)
        except:
            print(outputs)
            print(c)
        # print(outputs.logits)
        predicted_templates = torch.max(outputs.logits, 1).indices
        del outputs
        y_hat_list += predicted_templates.cpu().numpy().tolist()
    return y_hat_list


# Prepend the predicted template label to each instruction
x_val_seen_p = [str(int(vs_idx2predicted_label[i])) + ' ' + x for i, x in enumerate(x_val_seen)]

encoding_v_s = tokenizer(x_val_seen_p, return_tensors='pt', padding=True, truncation=True)
print(encoding_v_s)
input_ids_val_seen = encoding_v_s['input_ids'].to(device)
print(input_ids_val_seen.shape)
attention_mask_val_seen = encoding_v_s['attention_mask'].to(device)

parent_target_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(recep2idx)).to(device)
parent_target_classifier.load_state_dict(torch.load(os.path.join(save_folder_name, 'parent.pt')))
parent_outputs_hat = get_prediction(parent_target_classifier, 9, input_ids_val_seen, attention_mask_val_seen)
del parent_target_classifier
c = 0

object_target_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(obj2idx)).to(device)
object_target_classifier.load_state_dict(torch.load(os.path.join(save_folder_name, 'object.pt')))
object_outputs_hat = get_prediction(object_target_classifier, 9, input_ids_val_seen, attention_mask_val_seen)
del object_target_classifier
c = 0

sliced_target_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)
sliced_target_classifier.load_state_dict(torch.load(os.path.join(save_folder_name, 'sliced.pt')))
sliced_outputs_hat = get_prediction(sliced_target_classifier, 9, input_ids_val_seen, attention_mask_val_seen)
del sliced_target_classifier
c = 0

mrecep_target_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(obj2idx)).to(device)
mrecep_target_classifier.load_state_dict(torch.load(os.path.join(save_folder_name, 'mrecep.pt')))
mrecep_outputs_hat = get_prediction(mrecep_target_classifier, 9, input_ids_val_seen, attention_mask_val_seen)
del mrecep_target_classifier

instructions = val_set_unseen['x_low']
if args.no_appended:
    instructions = val_set_unseen['x']
instructions = val_set_unseen['x']  # output keys are the normalized task descriptions ('x'), which is what get_arguments_test looks up
instruction2_params_test_unseen = {}
for i, instruction in enumerate(instructions):
    task_type = vs_idx2predicted_label[i]
    object_target = idx2obj[object_outputs_hat[i]]
    if parent_outputs_hat is None:
        parent_target = None
    else:
        parent_target = idx2recep[parent_outputs_hat[i]]
    if mrecep_outputs_hat is None:
        mrecep_target = None
    else:
        mrecep_target = idx2obj[mrecep_outputs_hat[i]]
    sliced_target = sliced_outputs_hat[i]

    if task_type == 5:
        parent_target = None
    # if task_type != 1:
    #     mrecep_target = None

    instruction2_params_test_unseen[instruction] = {'task_type': task_type,
                                                    'mrecep_target': mrecep_target,
                                                    'sliced': sliced_target,
                                                    'object_target': object_target,
                                                    'parent_target': parent_target}


pickle.dump(instruction2_params_test_unseen, open("../" + args.output_name + ".p", "wb"))

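Not part of the commit: a minimal sketch of batching equivalent to the loops above. Tensor slicing clamps at the end, so iterating over ceil(len/N) batches and skipping empty chunks covers both the divisible and non-divisible cases without the special-casing the script does with `until`.

import torch

def batches(t: torch.Tensor, n: int):
    for b in range((t.shape[0] + n - 1) // n):   # ceil division
        chunk = t[n * b : n * (b + 1)]           # last chunk may be short
        if chunk.shape[0] > 0:
            yield chunk

x = torch.arange(23)
print([c.tolist() for c in batches(x, 10)])  # chunks of size 10, 10, 3
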
models/instructions_processed_LP/BERT/train_bert_args.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Created on Wed Feb 3 21:49:31 2021
|
5 |
+
|
6 |
+
@author: soyeonmin
|
7 |
+
"""
|
8 |
+
from tensorboardX import SummaryWriter
|
9 |
+
from tqdm import trange
|
10 |
+
import random
|
11 |
+
import time
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
import os
|
15 |
+
import glob
|
16 |
+
from collections import OrderedDict
|
17 |
+
import argparse
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument('-lr','--learning_rate', type=float, default=1e-5, help="learning rate")
|
20 |
+
parser.add_argument('-s','--seed', type=int, default=0, help="seed")
|
21 |
+
parser.add_argument('-l','--label', type=int, default=1, help="template")
|
22 |
+
parser.add_argument('-d','--decay', type=float, default=0.5, help="template")
|
23 |
+
parser.add_argument('-dt','--decay_term', type=int, default=50, help="template")
|
24 |
+
parser.add_argument('-t','--task', type=str, help="type in one of mrecep, object, parent, toggle, sliced")
|
25 |
+
parser.add_argument('-v','--verbose', type=int, default=0, help="print training output")
|
26 |
+
parser.add_argument('-model_type','--model_type', type=str, default='bert-base', help="one of roberta-large, roberta-base, bert-base, bert-large")
|
27 |
+
parser.add_argument('-load','--load', type=str, default='', help="one of roberta-large, roberta-base, bert-base, bert-large")
|
28 |
+
parser.add_argument('-no_divided_label','--no_divided_label', action='store_true')
|
29 |
+
parser.add_argument('--no_appended', action='store_true')
|
30 |
+
|
31 |
+
|
32 |
+
args = parser.parse_args()
|
33 |
+
|
34 |
+
#Set device
|
35 |
+
if torch.cuda.is_available():
|
36 |
+
device = torch.device('cuda')
|
37 |
+
else:
|
38 |
+
device = torch.device('cpu')
|
39 |
+
|
40 |
+
import pickle
|
41 |
+
template_by_label = pickle.load(open('data/alfred_data/alfred_dicts/template_by_label.p', 'rb'))
|
42 |
+
# train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended.p', 'rb'))
|
43 |
+
# val_set_seen = pickle.load(open('data/alfred_data/val_seen_text_with_ppdl_low_appended.p', 'rb'))
|
44 |
+
# val_set_unseen = pickle.load(open('data/alfred_data/val_unseen_text_with_ppdl_low_appended.p', 'rb'))
|
45 |
+
|
46 |
+
# obj2idx = pickle.load(open('data/alfred_data/alfred_dicts/obj2idx.p', 'rb'))
|
47 |
+
# recep2idx = pickle.load(open('data/alfred_data/alfred_dicts/recep2idx.p', 'rb'))
|
48 |
+
# train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended_new_split.p', 'rb'))
|
49 |
+
# val_set_seen = pickle.load(open('data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split.p', 'rb'))
|
50 |
+
# val_set_unseen = pickle.load(open('data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split.p', 'rb'))
|
51 |
+
|
52 |
+
# 0514
|
53 |
+
train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
|
54 |
+
val_set_seen = pickle.load(open('data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
|
55 |
+
val_set_unseen = pickle.load(open('data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
|
56 |
+
|
57 |
+
obj2idx = pickle.load(open('data/alfred_data/alfred_dicts/obj2idx_new_split.p', 'rb'))
|
58 |
+
recep2idx = pickle.load(open('data/alfred_data/alfred_dicts/recep2idx_new_split.p', 'rb'))
|
59 |
+
toggle2idx = pickle.load(open('data/alfred_data/alfred_dicts/toggle2idx.p', 'rb'))
|
60 |
+
|
61 |
+
assert args.task in ['mrecep', 'object', 'parent', 'toggle', 'sliced']
if args.task == 'mrecep':
    label_arg = 'mrecep_targets'; num_labels = len(obj2idx)
elif args.task == 'object':
    label_arg = 'object_targets'; num_labels = len(obj2idx)
elif args.task == 'parent':
    label_arg = 'parent_targets'; num_labels = len(recep2idx)
elif args.task == 'toggle':
    label_arg = 'toggle_targets'; num_labels = len(toggle2idx)
elif args.task == 'sliced':
    label_arg = 's'; num_labels = 2

if args.model_type in ['bert-base', 'bert-large']:
    from transformers import BertTokenizer as Tokenizer
    tok_type = args.model_type + '-uncased'
    from transformers import BertForSequenceClassification as BertModel
elif args.model_type in ['roberta-base', 'roberta-large']:
    from transformers import RobertaTokenizer as Tokenizer
    tok_type = args.model_type
    from transformers import RobertaForSequenceClassification as BertModel
tokenizer = Tokenizer.from_pretrained(tok_type)

torch.manual_seed(args.seed)
model = BertModel.from_pretrained(tok_type, num_labels=num_labels)

if len(args.load) > 0:
    # Load a pretrained checkpoint, but keep this model's freshly initialized
    # classifier head (the checkpoint's label space may differ).
    sd = OrderedDict()
    sd_ori = torch.load(args.load, map_location='cpu')
    for k in sd_ori:
        if 'classifier' in k:
            if k == 'classifier.weight':
                sd[k] = model.classifier.weight
            elif k == 'classifier.bias':
                sd[k] = model.classifier.bias
        else:
            sd[k] = sd_ori[k]
    del sd_ori

    model.load_state_dict(sd)

model = model.to(device)

####################################################
## 1. Prepare Data
####################################################


from random import shuffle
import numpy as np
x_train = train_set['x_low']; y_train = train_set['y']
if args.no_appended:
    x_train = train_set['x']
#if args.no_divided_label:
# Prepend the template label to each instruction string
x_train = [str(y_train[i]) + ' ' + x for i, x in enumerate(x_train)]

if args.task == 'sliced' or args.no_divided_label:
    x_ones = [i for i in range(len(x_train))]
else:
    # Keep only examples whose template label matches args.label
    x_ones = [i for i, lab in enumerate(y_train) if lab == args.label]
random.seed(0)
shuffled = [i for i in x_ones]
shuffle(shuffled)
#x_train = np.array(train_set['x_low'])[shuffled]
x_train = np.array(x_train)[shuffled]
x_train = x_train.tolist()
encoding = tokenizer(x_train, return_tensors='pt', padding=True, truncation=True)
input_ids = encoding['input_ids'].to(device) #has shape [21025, 37]
attention_mask = encoding['attention_mask'].to(device) #also has shape [21025, 37]

label_train = np.array(train_set[label_arg])[shuffled]
label_train = torch.tensor(label_train).to(device)

x_val_seen = val_set_seen['x_low']; y_val_seen = val_set_seen['y']
if args.no_appended:
    x_val_seen = val_set_seen['x']
x_val_seen = [str(y_val_seen[i]) + ' ' + x for i, x in enumerate(x_val_seen)]

if args.task == 'sliced' or args.no_divided_label:
    random.seed(0)
    # Evaluate on a random sample of 100 validation examples
    x_ones_vs = random.sample(range(len(x_val_seen)), 100)
else:
    x_ones_vs = [i for i, lab in enumerate(y_val_seen) if lab == args.label]
x_val_seen = np.array(x_val_seen)[x_ones_vs].tolist()
encoding_v_s = tokenizer(x_val_seen, return_tensors='pt', padding=True, truncation=True)
input_ids_val_seen = encoding_v_s['input_ids'].to(device)
attention_mask_val_seen = encoding_v_s['attention_mask'].to(device)

label_val_seen = [val_set_seen[label_arg][i] for i in x_ones_vs]
label_val_seen = torch.tensor(label_val_seen).to(device)

####################################################
## 2. Do training
####################################################

model.train()
from transformers import AdamW
learning_rate = args.learning_rate
optimizer = AdamW(model.parameters(), lr=learning_rate)
# Batch size: larger for the smaller (base) models
if 'base' in args.model_type:
    N = 64
else:
    N = 32

def accuracy(y_pred, y_batch):
    #y_pred has shape [batch, no_classes]
    maxed = torch.max(y_pred, 1)
    y_hat = maxed.indices
    num_accurate = torch.sum((y_hat == y_batch).float())
    train_accuracy = num_accurate / y_hat.shape[0]
    return train_accuracy.item()

def accurate_total(y_pred, y_batch):
    #y_pred has shape [batch, no_classes]
    maxed = torch.max(y_pred, 1)
    y_hat = maxed.indices
    num_accurate = torch.sum((y_hat == y_batch).float())
    return num_accurate
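
# Quick sanity check for the two helpers above (illustrative, not part of training):
#   logits = torch.tensor([[0.1, 2.0], [3.0, 0.2]])
#   accuracy(logits, torch.tensor([1, 0]))        # -> 1.0
#   accurate_total(logits, torch.tensor([1, 1]))  # -> tensor(1.)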

if args.no_appended:
    super_folder = 'saved_models_noappended_new_split_oct24/'
else:
    super_folder = 'saved_models_appended_new_split_oct24/'

if args.task == 'sliced':
    save_folder_name = args.task + '/' + args.model_type + '_lr_' + str(args.learning_rate) + 'seed_' + str(args.seed) + 'decay_' + str(args.decay) + '/'
else:
    save_folder_name = args.task + str(args.label) + '/' + args.model_type + '_lr_' + str(args.learning_rate) + 'seed_' + str(args.seed) + 'decay_' + str(args.decay) + '/'
if not os.path.exists(super_folder + 'argument_models/' + save_folder_name):
    os.makedirs(super_folder + 'argument_models/' + save_folder_name)
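# Given the concatenation above, checkpoints land in paths shaped like
# saved_models_appended_new_split_oct24/argument_models/object1/bert-base_lr_1e-05seed_0decay_0.5/
# (the missing separators before 'seed_' and 'decay_' follow from the string
# concatenation itself).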

accuracy_dictionary = {'training_loss': [], 'training': [], 'test': []}
start_train_time = time.time()

summary_writer = SummaryWriter(log_dir=super_folder + 'argument_models/' + save_folder_name)
train_iter = 0

for t in trange(50):
    model.train()
    if t > 0 and (t+1) % args.decay_term == 0:
        # Step-decay the learning rate by rebuilding the optimizer
        learning_rate *= args.decay
        optimizer = AdamW(model.parameters(), lr=learning_rate)
    avg_training_loss = 0.0
    training_acc = 0.0
    for b in trange(int(input_ids.shape[0]/N)):
        input_ids_batch = input_ids[N*b:N*(b+1)].to(device)
        labels_batch = label_train[N*b:N*(b+1)].to(device)
        attention_mask_batch = attention_mask[N*b:N*(b+1)].to(device)
        optimizer.zero_grad()
        #forward pass
        outputs = model(input_ids_batch, attention_mask=attention_mask_batch, labels=labels_batch.view(1,-1))
        loss = outputs.loss
        summary_writer.add_scalar('train/loss', loss, train_iter)

        num_acc = accurate_total(y_pred=outputs.logits, y_batch=labels_batch.view(-1))
        #if t == 0:
        #    print("loss at step ", t, " : ", loss.item())
        loss.backward()
        train_iter += 1
        optimizer.step()
        avg_training_loss += loss
        training_acc += num_acc
    avg_training_loss *= 1/int(input_ids.shape[0]/N)
    #Print & Evaluate
    if args.verbose:
        print("loss at step ", t, " : ", loss.item())
        print("training accuracy: ", training_acc/input_ids.shape[0])
    #evaluate
    with torch.no_grad():
        model.eval()
        outputs_val_seen = model(input_ids_val_seen, attention_mask=attention_mask_val_seen)
        val_seen_acc = accuracy(y_pred=outputs_val_seen.logits, y_batch=label_val_seen.view(-1))
        if args.verbose:
            print("validation accuracy (seen): ", val_seen_acc)
        del outputs_val_seen

    model_name = 'epoch_' + str(t) + '.pt'
    torch.save(model.state_dict(), super_folder + 'argument_models/' + save_folder_name + model_name)
    accuracy_dictionary['training_loss'].append(loss.item())
    accuracy_dictionary['training'].append(training_acc/input_ids.shape[0])
    accuracy_dictionary['test'].append(val_seen_acc)
    summary_writer.add_scalar('valid/accuracy', val_seen_acc, train_iter)

    print("Epoch " + str(t) + "\n")
    print("training loss: " + str(accuracy_dictionary['training_loss'][t]))
    print("training accuracy: " + str(accuracy_dictionary['training'][t]))
    print("validation (seen) accuracy: " + str(accuracy_dictionary['test'][t]))

#Get the highest accuracy and delete the rest
highest_test = np.argwhere(accuracy_dictionary['test'] == np.amax(accuracy_dictionary['test']))
highest_test = highest_test.flatten().tolist()

# Among epochs tied on validation accuracy, keep the one with the highest training accuracy
training_acc_highest_h = np.argmax([accuracy_dictionary['training'][h].detach().cpu().numpy() for h in highest_test])
best_t = highest_test[training_acc_highest_h]
print("The best model is 'epoch_" + str(best_t) + ".pt'")

#Delete every model except for best_t
file_path = super_folder + 'argument_models/' + save_folder_name + "epoch_" + str(best_t) + ".pt"
if os.path.isfile(file_path):
    for CleanUp in glob.glob(super_folder + 'argument_models/' + save_folder_name + '*.pt'):
        if not CleanUp.endswith(file_path):
            os.remove(CleanUp)
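# Note: file_path is the full path of the best checkpoint, so the endswith()
# guard keeps exactly epoch_<best_t>.pt and removes every other epoch_*.pt.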

#Save training/ test accuracy dictionary in a txt file
f = open(super_folder + 'argument_models/' + save_folder_name + "training_log.txt", "w")
for t in range(100):

    if t >= len(accuracy_dictionary['training_loss']):
        break
    if t == best_t:
        f.write("===========================================\n")
    f.write("Epoch " + str(t) + "\n")
    f.write("training loss: " + str(accuracy_dictionary['training_loss'][t]) + "\n")
    f.write("training accuracy: " + str(accuracy_dictionary['training'][t]) + "\n")
    f.write("validation (seen) accuracy: " + str(accuracy_dictionary['test'][t]) + "\n")

    if t == best_t:
        f.write("===========================================\n")
f.close()


#Print training/ test accuracy for t
print("Saved and finished")

models/instructions_processed_LP/BERT/train_bert_base.py
ADDED
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 20 11:41:44 2021

@author: soyeonmin
"""

#Train BERT on ALFRED data
import tensorboard
from tensorboardX import SummaryWriter
import random
import time
import torch
from torch import nn
import os
from tqdm import trange
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-lr', '--learning_rate', type=float, help="learning rate")
parser.add_argument('-model_type', '--model_type', type=str, default='bert-base', help="one of roberta-large, roberta-base, bert-base, bert-large")
parser.add_argument('-task_desc_only', '--task_desc_only', action='store_true', help="use only the task description as input")
parser.add_argument('--no_appended', action='store_true')


args = parser.parse_args()
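# Example invocation (illustrative; --learning_rate has no default here, so it
# should be passed explicitly):
#   python train_bert_base.py -lr 1e-5 -model_type bert-base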

##########
#### 1. Data loading and processing
#########

#Set device
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


import pickle
directory = 'data/alfred_data/'
# train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended.p', 'rb'))
# val_set_seen = pickle.load(open('data/alfred_data/val_seen_text_with_ppdl_low_appended.p', 'rb'))
# val_set_unseen = pickle.load(open('data/alfred_data/val_unseen_text_with_ppdl_low_appended.p', 'rb'))

# train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended_new_split.p', 'rb'))
# val_set_seen = pickle.load(open('data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split.p', 'rb'))
# val_set_unseen = pickle.load(open('data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split.p', 'rb'))

train_set = pickle.load(open('data/alfred_data/train_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
val_set_seen = pickle.load(open('data/alfred_data/valid_seen_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))
val_set_unseen = pickle.load(open('data/alfred_data/valid_unseen_text_with_ppdl_low_appended_new_split_oct24.p', 'rb'))

#Now process like huggingface
if args.no_appended:
    x_train = train_set['x']; y_train = train_set['y']
else:
    x_train = train_set['x_low']; y_train = train_set['y']
labels = torch.tensor(y_train).to(device)

if args.model_type in ['bert-base', 'bert-large']:
    from transformers import BertTokenizer as Tokenizer
    tok_type = args.model_type + '-uncased'
    from transformers import BertForSequenceClassification as BertModel
elif args.model_type in ['roberta-base', 'roberta-large']:
    from transformers import RobertaTokenizer as Tokenizer
    tok_type = args.model_type
    from transformers import RobertaForSequenceClassification as BertModel
tokenizer = Tokenizer.from_pretrained(tok_type)

model = BertModel.from_pretrained(tok_type, num_labels=7).to(device)  # 7 ALFRED task types
encoding = tokenizer(x_train, return_tensors='pt', padding=True, truncation=True)

input_ids = encoding['input_ids'].to(device) #has shape [21025, 37]
attention_mask = encoding['attention_mask'].to(device) #also has shape [21025, 37]

#Get input_ids, attention_mask for val_seen, val_unseen
x_val_seen = val_set_seen['x_low']; y_val_seen = val_set_seen['y']
if args.no_appended:
    x_val_seen = val_set_seen['x']
#Just sample a few from x_val_seen, y_val_seen
random.seed(0)
sampled_val_seen = random.sample(range(len(x_val_seen)), 100)
x_val_seen = [x_val_seen[i] for i in sampled_val_seen]; y_val_seen = [y_val_seen[i] for i in sampled_val_seen]
labels_val_seen = torch.tensor(y_val_seen).to(device)
encoding_v_s = tokenizer(x_val_seen, return_tensors='pt', padding=True, truncation=True)
input_ids_val_seen = encoding_v_s['input_ids'].to(device)
attention_mask_val_seen = encoding_v_s['attention_mask'].to(device)

x_val_unseen = val_set_unseen['x_low']; y_val_unseen = val_set_unseen['y']
if args.no_appended:
    x_val_unseen = val_set_unseen['x']
random.seed(1)
sampled_val_unseen = random.sample(range(len(x_val_unseen)), 100)
x_val_unseen = [x_val_unseen[i] for i in sampled_val_unseen]; y_val_unseen = [y_val_unseen[i] for i in sampled_val_unseen]
labels_val_unseen = torch.tensor(y_val_unseen).to(device)
encoding_v_us = tokenizer(x_val_unseen, return_tensors='pt', padding=True, truncation=True)
input_ids_val_unseen = encoding_v_us['input_ids'].to(device)
attention_mask_val_unseen = encoding_v_us['attention_mask'].to(device)

##########
#### 2. Do training
#########

model.train()
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=args.learning_rate)
N = 16 #batch size
if args.no_appended:
    N = 64
save_folder_name = args.model_type + '_lr_' + str(args.learning_rate) + '/'
if args.no_appended:
    super_folder = 'saved_models_noappended_new_split_oct24/'
else:
    super_folder = 'saved_models_appended_new_split_oct24/'
if not os.path.exists(super_folder + save_folder_name):
    os.makedirs(super_folder + save_folder_name)

def accuracy(y_pred, y_batch):
    #y_pred has shape [batch, no_classes]
    maxed = torch.max(y_pred, 1)
    y_hat = maxed.indices
    num_accurate = torch.sum((y_hat == y_batch).float())
    train_accuracy = num_accurate / y_hat.shape[0]
    return train_accuracy.item()

def accurate_total(y_pred, y_batch):
    #y_pred has shape [batch, no_classes]
    maxed = torch.max(y_pred, 1)
    y_hat = maxed.indices
    num_accurate = torch.sum((y_hat == y_batch).float())
    return num_accurate

start_train_time = time.time()
summary_writer = SummaryWriter(log_dir=super_folder + save_folder_name)
train_iter = 0
for t in trange(50):
    model.train()
    avg_training_loss = 0.0
    training_acc = 0.0
    for b in trange(int(input_ids.shape[0]/N)):
        input_ids_batch = input_ids[N*b:N*(b+1)].to(device)
        labels_batch = labels[N*b:N*(b+1)].to(device)
        attention_mask_batch = attention_mask[N*b:N*(b+1)].to(device)
        optimizer.zero_grad()
        #forward pass
        outputs = model(input_ids_batch, attention_mask=attention_mask_batch, labels=labels_batch.view(1,-1))
        loss = outputs.loss
        summary_writer.add_scalar('train/loss', loss, train_iter)

        num_acc = accurate_total(y_pred=outputs.logits, y_batch=labels_batch.view(-1))
        loss.backward()
        train_iter += 1
        optimizer.step()
        avg_training_loss += loss
        training_acc += num_acc
    avg_training_loss *= 1/int(input_ids.shape[0]/N)
    #Print & Evaluate
    if t % 1 == 0:
        with torch.no_grad():
            model.eval()
            print("loss at step ", t, " : ", loss.item())
            print("training accuracy: ", training_acc/input_ids.shape[0])
            #evaluate
            outputs_val_seen = model(input_ids_val_seen, attention_mask=attention_mask_val_seen)
            print("validation accuracy (seen): ", accuracy(y_pred=outputs_val_seen.logits, y_batch=labels_val_seen.view(-1)))
            del outputs_val_seen
            outputs_val_unseen = model(input_ids_val_unseen, attention_mask=attention_mask_val_unseen)
            print("validation accuracy (unseen): ", accuracy(y_pred=outputs_val_unseen.logits, y_batch=labels_val_unseen.view(-1)))

    model_name = 'epoch_' + str(t) + '.pt'
    torch.save(model.state_dict(), super_folder + save_folder_name + model_name)
models/instructions_processed_LP/compare_BERT_pred_with_GT_oct24.py
ADDED
@@ -0,0 +1,83 @@
import pickle

splits = ["valid_seen", "valid_unseen", "tests_seen", "tests_unseen"]
split = "valid_unseen"

pred = pickle.load(open('instruction2_params_' + split + '_appended_new_split_oct24.p', 'rb'))
gt = pickle.load(open('instruction2_params_' + split + '_new_split_GT_oct24.p', 'rb'))
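
# Both pickles map a task-description string to a dict of arguments; the keys
# compared below are 'task_type', 'mrecep_target', 'object_target',
# 'parent_target', and 'sliced' ('toggle_target' exists but its check is
# commented out).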

task_desc_list = list(pred.keys())
# gt = list(gt.items())
task_types = {"pick_cool_then_place_in_recep": 0, "pick_and_place_with_movable_recep": 1, "pick_and_place_simple": 2, "pick_two_obj_and_place": 3, "pick_heat_then_place_in_recep": 4, "look_at_obj_in_light": 5, "pick_clean_then_place_in_recep": 6}

# c<k> counts examples with exactly k correct fields; nt/nm/no/np/ns count
# correct task_type/mrecep/object/parent/sliced predictions.
c0 = 0; c1 = 0; c2 = 0; c3 = 0; c4 = 0; c5 = 0; c6 = 0
nt = 0; nm = 0; no = 0; np = 0; ns = 0
parent_pred = []; parent_gt = []; parents = []
# num_of_tables_in_parent = 0
for i in range(len(pred)):
    n = 0
    task = task_desc_list[i]
    print("##########################################################################")
    print(task)
    if pred[task]['task_type'] == gt[task]['task_type']:
        nt += 1; n += 1
    else:
        print('Task type: ', pred[task]['task_type'], gt[task]['task_type'])

    if pred[task]['mrecep_target'] == gt[task]['mrecep_target']:
        nm += 1; n += 1
    else:
        print('mrecep: ', pred[task]['mrecep_target'], gt[task]['mrecep_target'])

    if pred[task]['object_target'] == gt[task]['object_target']:
        no += 1; n += 1
    else:
        print('object_target: ', pred[task]['object_target'], gt[task]['object_target'])

    if pred[task]['parent_target'] == gt[task]['parent_target']:
        np += 1; n += 1
    else:
        print('parent_target: ', pred[task]['parent_target'], gt[task]['parent_target'])
        parent_pred.append(pred[task]['parent_target'])
        parent_gt.append(gt[task]['parent_target'])
        parents.append([pred[task]['parent_target'], gt[task]['parent_target']])

    if pred[task]['sliced'] == gt[task]['sliced']:
        ns += 1; n += 1
    else:
        print('sliced: ', pred[task]['sliced'], gt[task]['sliced'])
    # if pred[task]['toggle_target'] == gt[task]['toggle_target']:
    #     nt += 1; n += 1
    print("##########################################################################")

    if n == 0:
        c0 += 1
    elif n == 1:
        c1 += 1
    elif n == 2:
        c2 += 1
    elif n == 3:
        c3 += 1
    elif n == 4:
        c4 += 1
    elif n == 5:
        c5 += 1
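
# Five fields are compared per task, so n ranges over 0..5 and c5 counts the
# tasks with every field correct (reported below as the all-correct accuracy).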
print("################################################")
|
68 |
+
print(split)
|
69 |
+
print(f"Total number of task_desc: {len(pred), len(gt)}\n")
|
70 |
+
print(f"task_type accuracy: {round(nt/len(pred)*100, 2)}")
|
71 |
+
print(f"mrecep_target accuracy: {round(nm/len(pred)*100, 2)}")
|
72 |
+
print(f"object_target accuracy: {round(no/len(pred)*100, 2)}")
|
73 |
+
print(f"parent_target accuracy: {round(np/len(pred)*100, 2)}")
|
74 |
+
print(f"sliced accuracy: {round(ns/len(pred)*100, 2)}")
|
75 |
+
print()
|
76 |
+
print(f"0 Correct: {round(c0/len(pred)*100, 2)}")
|
77 |
+
print(f"1 Correct: {round(c1/len(pred)*100, 2)}")
|
78 |
+
print(f"2 Correct: {round(c2/len(pred)*100, 2)}")
|
79 |
+
print(f"3 Correct: {round(c3/len(pred)*100, 2)}")
|
80 |
+
print(f"4 Correct: {round(c4/len(pred)*100, 2)}")
|
81 |
+
print(f"5 Correct: {round(c5/len(pred)*100, 2)}")
|
82 |
+
print('\nTotal all correct accuracy: ', str(round(c5/len(pred)*100, 2)))
|
83 |
+
print("################################################")
|
models/instructions_processed_LP/instruction2_params_tests_seen_appended_new_split_oct24.p
ADDED
Binary file (178 kB)

models/instructions_processed_LP/instruction2_params_tests_seen_new_split_GT_oct24.p
ADDED
Binary file (244 kB)

models/instructions_processed_LP/instruction2_params_tests_unseen_appended_new_split_oct24.p
ADDED
Binary file (173 kB)

models/instructions_processed_LP/instruction2_params_tests_unseen_new_split_GT_oct24.p
ADDED
Binary file (233 kB)

models/instructions_processed_LP/instruction2_params_valid_seen_appended_new_split_oct24.p
ADDED
Binary file (108 kB)

models/instructions_processed_LP/instruction2_params_valid_seen_new_split_GT_oct24.p
ADDED
Binary file (146 kB)

models/instructions_processed_LP/instruction2_params_valid_unseen_appended_new_split_oct24.p
ADDED
Binary file (103 kB)

models/instructions_processed_LP/instruction2_params_valid_unseen_new_split_GT_oct24.p
ADDED
Binary file (140 kB)