dinhquangson committed on
Commit
bae4b3b
1 Parent(s): 5cfedf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -9
app.py CHANGED
@@ -54,18 +54,13 @@ client2.create_collection(
54
  ),
55
  )
56
 
57
-
58
-
59
-
60
-
61
-
62
  # Create function to generate embeddings (in batches) for a given dataset split
63
- def generate_embeddings(dataset, batch_size=32):
64
  embeddings = []
65
 
66
  with tqdm(total=len(dataset), desc=f"Generating embeddings for dataset") as pbar:
67
  for i in range(0, len(dataset), batch_size):
68
- batch_sentences = dataset['content'][i:i+batch_size]
69
  batch_embeddings = model.encode(batch_sentences)
70
  embeddings.extend(batch_embeddings)
71
  pbar.update(len(batch_sentences))
@@ -73,7 +68,7 @@ def generate_embeddings(dataset, batch_size=32):
73
  return embeddings
74
 
75
  @app.post("/uploadfile/")
76
- async def create_upload_file(file: UploadFile = File(...)):
77
  file_savePath = join(temp_path,file.filename)
78
 
79
  with open(file_savePath,'wb') as f:
@@ -95,7 +90,7 @@ async def create_upload_file(file: UploadFile = File(...)):
95
  else:
96
  raise NotImplementedError("This feature is not supported yet")
97
  # Generate and append embeddings to the train split
98
- law_embeddings = generate_embeddings(full_dataset)
99
  full_dataset= full_dataset.add_column("embeddings", law_embeddings)
100
 
101
  if not 'uuid' in full_dataset.column_names:
 
54
  ),
55
  )
56
 
 
 
 
 
 
57
  # Create function to generate embeddings (in batches) for a given dataset split
58
+ def generate_embeddings(dataset, text_field, batch_size=32):
59
  embeddings = []
60
 
61
  with tqdm(total=len(dataset), desc=f"Generating embeddings for dataset") as pbar:
62
  for i in range(0, len(dataset), batch_size):
63
+ batch_sentences = dataset[text_field][i:i+batch_size]
64
  batch_embeddings = model.encode(batch_sentences)
65
  embeddings.extend(batch_embeddings)
66
  pbar.update(len(batch_sentences))
 
68
  return embeddings
69
 
70
  @app.post("/uploadfile/")
71
+ async def create_upload_file(file: UploadFile = File(...), text_field: str):
72
  file_savePath = join(temp_path,file.filename)
73
 
74
  with open(file_savePath,'wb') as f:
 
90
  else:
91
  raise NotImplementedError("This feature is not supported yet")
92
  # Generate and append embeddings to the train split
93
+ law_embeddings = generate_embeddings(full_dataset, text_field)
94
  full_dataset= full_dataset.add_column("embeddings", law_embeddings)
95
 
96
  if not 'uuid' in full_dataset.column_names: