How to Directly Obtain Gene and Cell Embedding Changes Without Predefined Cell States?

#454
by HarryShen666 - opened

Hi ctheodoris,

I noticed that the InSilicoPerturber focuses on predefined cell states and provides averaged summaries of embeddings for perturbations. However, my goal requires directly obtaining the gene and cell embeddings for all my given cells under specific gene perturbations, without averaging or focusing on predefined cell states. These embeddings will be used in downstream tasks beyond predefined cell state analysis.

Could you please guide me on how to achieve this using Geneformer? Are there any modifications or specific parameters in the InSilicoPerturber class that allow this?

Thank you!

Thank you for your question! You can accomplish this by using the EmbExtractor to output the embeddings directly. You can modify the input data to reflect the specific perturbations you are interested in.
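To keep per-cell results instead of averages, one option is to extract embeddings for both the original and the perturbed versions of the same cells and compare them cell by cell. The sketch below assumes both sets of cell embeddings are already loaded (for example, from the EmbExtractor output files) and keyed by a cell identifier that you carry through yourself. Matching by ID rather than by row position is a precaution, since the extraction step may not preserve the input order. The function names are hypothetical, and cosine similarity is only one possible measure of the shift.

```python
import math
from typing import Dict, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length, non-zero vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def per_cell_cosine_shift(
    original: Dict[str, Sequence[float]],
    perturbed: Dict[str, Sequence[float]],
) -> Dict[str, float]:
    """One cosine similarity per cell ID present in both inputs (no averaging).

    1.0 means the perturbation left the direction of the cell embedding
    unchanged; lower values mean a larger shift.
    """
    shared_ids = sorted(original.keys() & perturbed.keys())
    return {cid: cosine_similarity(original[cid], perturbed[cid]) for cid in shared_ids}


if __name__ == "__main__":
    # Made-up 3-dimensional embeddings keyed by hypothetical cell IDs
    original = {"cell_a": [1.0, 0.0, 0.0], "cell_b": [0.0, 1.0, 0.0]}
    perturbed = {"cell_a": [1.0, 0.0, 0.0], "cell_b": [1.0, 0.0, 0.0]}
    print(per_cell_cosine_shift(original, perturbed))  # {'cell_a': 1.0, 'cell_b': 0.0}
```

Cells that appear in only one of the two inputs are skipped. If every cell is expected in both, you may prefer to raise an error instead.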

ctheodoris changed discussion status to closed

@ctheodoris
Hi ctheodoris,

Thank you so much for your guidance!

To do this, I’ve been experimenting with creating perturbed inputs by adapting the token-perturbation functions in in_silico_perturber.py (code snippet below). Is this the right approach for generating perturbed inputs for EmbExtractor?

```python
from geneformer import perturber_utils as pu


def make_group_perturbation_batch(
    example, tokens_to_perturb, perturb_type, max_len, special_token=False
):
    # special_token replaces self.special_token from InSilicoPerturber; set it to
    # True if the model's rank value encodings include special (e.g. CLS) tokens
    example_input_ids = example["input_ids"]
    example["tokens_to_perturb"] = tokens_to_perturb
    # Position of each token to perturb in this cell's rank value encoding
    indices_to_perturb = [
        example_input_ids.index(token) if token in example_input_ids else None
        for token in tokens_to_perturb
    ]
    indices_to_perturb = [item for item in indices_to_perturb if item is not None]
    if len(indices_to_perturb) > 0:
        example["perturb_index"] = indices_to_perturb
    else:
        # -100 indicates tokens to perturb are not present in rank value encoding
        example["perturb_index"] = [-100]
    if perturb_type == "delete":
        # Only delete when at least one token is present; otherwise the -100
        # placeholder would be passed on as if it were a real list index
        if example["perturb_index"] != [-100]:
            example = pu.delete_indices(example)
    elif perturb_type == "overexpress":
        example = pu.overexpress_tokens(example, max_len, special_token)
        example["n_overflow"] = pu.calc_n_overflow(
            max_len,
            example["length"],
            tokens_to_perturb,
            indices_to_perturb,
        )
    return example


### define custom tokens_to_perturb, perturb_type, max_len, special_token
perturbed_data = filtered_input_data.map(
    make_group_perturbation_batch,
    fn_kwargs={
        "tokens_to_perturb": tokens_to_perturb,
        "perturb_type": perturb_type,
        "max_len": max_len,
        "special_token": special_token,
    },
    num_proc=20,
)
### Save perturbed_data to the datapath for EmbExtractor
```
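To reason about what the snippet should produce, it can help to apply the same two perturbation types to a plain token list. The function below is a toy, dependency-free approximation written for illustration. It is not Geneformer's implementation, and the name `perturb_rank_encoding` is hypothetical. It also ignores special tokens such as a leading CLS token, which the snippet handles through `special_token`.

```python
from typing import Dict, List


def perturb_rank_encoding(
    input_ids: List[int],
    tokens_to_perturb: List[int],
    perturb_type: str,
    max_len: int,
) -> Dict[str, object]:
    """Toy group perturbation of one rank value encoding (no special tokens).

    delete: drop every token from tokens_to_perturb that is present.
    overexpress: remove those tokens where present, put all of tokens_to_perturb
    at the front (top rank), then truncate to max_len.
    The input list is not modified.
    """
    present = [t for t in tokens_to_perturb if t in input_ids]
    # Genes not being perturbed keep their original relative order
    remaining = [t for t in input_ids if t not in tokens_to_perturb]
    if perturb_type == "delete":
        new_ids = remaining
    elif perturb_type == "overexpress":
        new_ids = (list(tokens_to_perturb) + remaining)[:max_len]
    else:
        raise ValueError(f"unsupported perturb_type: {perturb_type!r}")
    return {"input_ids": new_ids, "length": len(new_ids), "n_present": len(present)}


if __name__ == "__main__":
    cell = [10, 20, 30, 40, 50]  # hypothetical token IDs, highest rank first
    print(perturb_rank_encoding(cell, [30, 99], "delete", max_len=5)["input_ids"])  # [10, 20, 40, 50]
    print(perturb_rank_encoding(cell, [30, 99], "overexpress", max_len=5)["input_ids"])  # [30, 99, 10, 20, 40]
```

Comparing a few rows of the real perturbed dataset against this kind of expectation is a quick way to spot indexing or truncation surprises.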

Thank you in advance for your time and help!

Thank you for following up! The in silico perturber code is complex, and since we are not certain what you are trying to do, we cannot advise directly on this question. We would recommend checking the perturbed inputs you prepare to ensure they are in the format you expect. If the perturbed inputs are prepared correctly, you can use the EmbExtractor to extract the embeddings and use them in your downstream tasks.
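One way to check the perturbed inputs is a small per-cell check like the one below. It is a sketch, not part of Geneformer: the function name `check_perturbed_example` is hypothetical, it assumes encodings without special tokens, and the checks encode an assumed expectation of how each perturbation changes a rank value encoding. Under that expectation, deleted genes are removed, overexpressed genes move to the top ranks, the result fits within `max_len`, and the other genes keep their relative order.

```python
from typing import List, Sequence


def check_perturbed_example(
    original_ids: Sequence[int],
    perturbed_ids: Sequence[int],
    length: int,
    tokens_to_perturb: Sequence[int],
    perturb_type: str,
    max_len: int,
) -> List[str]:
    """Return the problems found for one cell (an empty list means all checks passed)."""
    problems = []
    if length != len(perturbed_ids):
        problems.append("'length' does not match len(input_ids)")
    if len(perturbed_ids) > max_len:
        problems.append("input_ids longer than max_len")
    # Unperturbed genes should keep their order (a truncated tail is allowed)
    others_before = [t for t in original_ids if t not in tokens_to_perturb]
    others_after = [t for t in perturbed_ids if t not in tokens_to_perturb]
    if others_after != others_before[: len(others_after)]:
        problems.append("unperturbed genes changed order")
    if perturb_type == "delete":
        if any(t in perturbed_ids for t in tokens_to_perturb):
            problems.append("a deleted token is still present")
    elif perturb_type == "overexpress":
        if list(perturbed_ids[: len(tokens_to_perturb)]) != list(tokens_to_perturb):
            problems.append("overexpressed tokens are not at the top ranks")
    else:
        problems.append(f"unknown perturb_type: {perturb_type!r}")
    return problems


if __name__ == "__main__":
    original = [10, 20, 30, 40, 50]  # hypothetical token IDs
    print(check_perturbed_example(original, [10, 20, 40, 50], 4, [30], "delete", 5))  # []
    print(check_perturbed_example(original, [10, 20, 30, 40], 4, [30], "delete", 5))  # ['a deleted token is still present']
```

To apply it to a whole dataset, you could iterate over the original and perturbed rows together, provided neither dataset was shuffled or reordered in between.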
