jykoh committed
Commit 9d41210
1 Parent(s): 819cc0a

Add truncation

Files changed (2)
  1. app.py +2 -2
  2. fromage/models.py +6 -1
app.py CHANGED
@@ -212,5 +212,5 @@ with gr.Blocks(css=css) as demo:
     save_button.click(None, [], [], _js=save_js)
 
 
-demo.queue(concurrency_count=1, api_open=False, max_size=16)
-demo.launch(debug=True, server_name="0.0.0.0")
+# demo.queue(concurrency_count=1, api_open=False, max_size=16)
+demo.launch(debug=True, server_name="127.0.0.1")
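This app.py change switches the Space to a local-debugging configuration: the request queue is commented out and the server binds to 127.0.0.1 (localhost only) instead of 0.0.0.0 (all interfaces). A minimal sketch of the two configurations using Gradio's standard queue()/launch() calls; the LOCAL_DEBUG flag and the placeholder UI below are illustrative assumptions, not code from this repo:

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")  # stand-in for the real FROMAGe interface

LOCAL_DEBUG = True  # hypothetical flag for this sketch

if LOCAL_DEBUG:
    # Local debugging: no request queue, reachable only from this machine.
    demo.launch(debug=True, server_name="127.0.0.1")
else:
    # Deployed demo: queue requests and listen on all network interfaces.
    demo.queue(concurrency_count=1, api_open=False, max_size=16)
    demo.launch(debug=False, server_name="0.0.0.0")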
fromage/models.py CHANGED
@@ -525,6 +525,11 @@ class Fromage(nn.Module):
                 raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
         input_embs = torch.cat(input_embs, dim=1)
         input_ids = torch.cat(input_ids, dim=1)
+        # Trim to a reasonable max length, for demo purposes.
+        start_idx = max(input_embs.shape[1] - 512, 0)
+        input_embs = input_embs[:, start_idx:, :]
+        input_ids = input_ids[:, start_idx:]
+        print('input_embs.shape', input_embs.shape)
 
         print('L529 called')
         if num_words == 0:
@@ -635,7 +640,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
     assert len(ret_token_idx) == 1, ret_token_idx
     model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
 
-    debug = False
+    debug = True
     if debug:
         model_kwargs['opt_version'] = 'facebook/opt-125m'
         model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
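The first hunk keeps only the last 512 positions of the concatenated prompt embeddings and token ids, so long multimodal prompts don't overflow the demo's context. A self-contained sketch of the same left-truncation on dummy tensors; only the 512 cap and the slicing come from the diff, while the sequence length of 700 and embedding size of 4096 are made-up values for illustration:

import torch

MAX_LEN = 512  # cap used in the commit

# Dummy stand-ins for the concatenated prompt embeddings (batch, seq_len, emb_dim) and token ids.
input_embs = torch.randn(1, 700, 4096)
input_ids = torch.randint(0, 50000, (1, 700))

# Keep only the most recent MAX_LEN positions, dropping the oldest tokens.
start_idx = max(input_embs.shape[1] - MAX_LEN, 0)
input_embs = input_embs[:, start_idx:, :]   # -> torch.Size([1, 512, 4096])
input_ids = input_ids[:, start_idx:]        # -> torch.Size([1, 512])
print('input_embs.shape', input_embs.shape)

The second hunk flips the debug flag so load_fromage substitutes the small facebook/opt-125m and openai/clip-vit-base-patch32 checkpoints, which speeds up local iteration.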