Alexander Seifert commited on
Commit
bb162b6
1 Parent(s): 408486e

add randomize_sample option

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. src/data.py +6 -2
  3. src/load.py +4 -1
  4. src/subpages/home.py +9 -1
README.md CHANGED
@@ -19,7 +19,7 @@ Error Analysis is an important but often overlooked part of the data science pro
19
 
20
  ### Activations
21
 
22
- A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
23
 
24
 
25
  ### Embeddings
 
19
 
20
  ### Activations
21
 
22
+ A group of neurons tends to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
23
 
24
 
25
  ### Embeddings
src/data.py CHANGED
@@ -11,7 +11,9 @@ from src.utils import device, tokenizer_hash_funcs
11
 
12
 
13
  @st.cache(allow_output_mutation=True)
14
- def get_data(ds_name: str, config_name: str, split_name: str, split_sample_size: int) -> Dataset:
 
 
15
  """Loads a Dataset from the HuggingFace hub (if not already loaded).
16
 
17
  Uses `datasets.load_dataset` to load the dataset (see its documentation for additional details).
@@ -25,7 +27,9 @@ def get_data(ds_name: str, config_name: str, split_name: str, split_sample_size:
25
  Returns:
26
  Dataset: A Dataset object.
27
  """
28
- ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(seed=0) # type: ignore
 
 
29
  split = ds[split_name].select(range(split_sample_size))
30
  return split
31
 
 
11
 
12
 
13
  @st.cache(allow_output_mutation=True)
14
+ def get_data(
15
+ ds_name: str, config_name: str, split_name: str, split_sample_size: int, randomize_sample: bool
16
+ ) -> Dataset:
17
  """Loads a Dataset from the HuggingFace hub (if not already loaded).
18
 
19
  Uses `datasets.load_dataset` to load the dataset (see its documentation for additional details).
 
27
  Returns:
28
  Dataset: A Dataset object.
29
  """
30
+ ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(
31
+ seed=0 if randomize_sample else None
32
+ ) # type: ignore
33
  split = ds[split_name].select(range(split_sample_size))
34
  return split
35
 
src/load.py CHANGED
@@ -37,6 +37,7 @@ def load_context(
37
  ds_config_name: str,
38
  ds_split_name: str,
39
  split_sample_size: int,
 
40
  **kw_args,
41
  ) -> Context:
42
  """Utility method loading (almost) everything we need for the application.
@@ -63,7 +64,9 @@ def load_context(
63
  collator = get_collator(tokenizer)
64
 
65
  # load data related stuff
66
- split: Dataset = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
 
 
67
  tags = split.features["ner_tags"].feature
68
  split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
69
 
 
37
  ds_config_name: str,
38
  ds_split_name: str,
39
  split_sample_size: int,
40
+ randomize_sample: bool,
41
  **kw_args,
42
  ) -> Context:
43
  """Utility method loading (almost) everything we need for the application.
 
64
  collator = get_collator(tokenizer)
65
 
66
  # load data related stuff
67
+ split: Dataset = get_data(
68
+ ds_name, ds_config_name, ds_split_name, split_sample_size, randomize_sample
69
+ )
70
  tags = split.features["ner_tags"].feature
71
  split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
72
 
src/subpages/home.py CHANGED
@@ -45,6 +45,7 @@ class HomePage(Page):
45
  "ds_split_name": "validation",
46
  "ds_config_name": _CONFIG_NAME,
47
  "split_sample_size": 512,
 
48
  }
49
 
50
  def render(self, context: Optional[Context] = None):
@@ -118,11 +119,18 @@ class HomePage(Page):
118
  key="split_sample_size",
119
  help="Sample size for the split, speeds up processing inside streamlit",
120
  )
 
 
 
 
 
121
  # breakpoint()
122
  # st.form_submit_button("Submit")
123
  st.form_submit_button("Load Model & Data")
124
 
125
- split = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
 
 
126
  labels = list(
127
  set([n.split("-")[1] for n in split.features["ner_tags"].feature.names if n != "O"])
128
  )
 
45
  "ds_split_name": "validation",
46
  "ds_config_name": _CONFIG_NAME,
47
  "split_sample_size": 512,
48
+ "randomize_sample": True,
49
  }
50
 
51
  def render(self, context: Optional[Context] = None):
 
119
  key="split_sample_size",
120
  help="Sample size for the split, speeds up processing inside streamlit",
121
  )
122
+ randomize_sample = st.checkbox(
123
+ "Randomize sample",
124
+ key="randomize_sample",
125
+ help="Whether to randomize the sample",
126
+ )
127
  # breakpoint()
128
  # st.form_submit_button("Submit")
129
  st.form_submit_button("Load Model & Data")
130
 
131
+ split = get_data(
132
+ ds_name, ds_config_name, ds_split_name, split_sample_size, randomize_sample # type: ignore
133
+ )
134
  labels = list(
135
  set([n.split("-")[1] for n in split.features["ner_tags"].feature.names if n != "O"])
136
  )