Spaces:
Running
on
T4
Running
on
T4
Jacob Gershon
commited on
Commit
•
29f8f7e
1
Parent(s):
59a9ccf
fixed but in partial seq diffusion
Browse files- utils/sampler.py +7 -1
utils/sampler.py
CHANGED
@@ -291,7 +291,13 @@ class SEQDIFF_sampler:
|
|
291 |
self.features['xyz_t'] = torch.full((1,1,len(self.args['sequence']),27,3), np.nan)
|
292 |
|
293 |
self.features['mask_str'] = torch.zeros(len(self.args['sequence'])).long()[None,:].bool()
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
self.features['blank_mask'] = torch.ones(self.features['mask_str'].size()[-1])[None,:].bool()
|
296 |
|
297 |
self.features['idx_pdb'] = torch.tensor([i for i in range(len(self.args['sequence']))])[None,:]
|
|
|
291 |
self.features['xyz_t'] = torch.full((1,1,len(self.args['sequence']),27,3), np.nan)
|
292 |
|
293 |
self.features['mask_str'] = torch.zeros(len(self.args['sequence'])).long()[None,:].bool()
|
294 |
+
|
295 |
+
#added check for if in partial diffusion mode will mask
|
296 |
+
if self.args['sampling_temp'] == 1.0:
|
297 |
+
self.features['mask_seq'] = torch.tensor([0 if x == 'X' else 1 for x in self.args['sequence']]).long()[None,:].bool()
|
298 |
+
else:
|
299 |
+
self.features['mask_seq'] = torch.zeros(len(self.args['sequence'])).long()[None,:].bool()
|
300 |
+
|
301 |
self.features['blank_mask'] = torch.ones(self.features['mask_str'].size()[-1])[None,:].bool()
|
302 |
|
303 |
self.features['idx_pdb'] = torch.tensor([i for i in range(len(self.args['sequence']))])[None,:]
|