"""Gradio demo that answers grade-school math questions with an implicit-CoT model.

The script loads the ``yuntian-deng/implicit-cot-math-mistral7b`` checkpoint,
wraps greedy generation in :func:`predict_answer`, and serves it through a
``gr.Interface`` with a small request queue.
"""

# NOTE(review): `spaces` is kept as the first import, as in the original file;
# presumably it must be imported before torch on ZeroGPU hardware -- confirm.
import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM


def _current_device():
    """Return ``'cuda'`` when CUDA is available right now, otherwise ``'cpu'``.

    Evaluated on every call (not cached) so that code running inside the
    ``@spaces.GPU`` context sees the availability at that moment.
    """
    return 'cuda' if torch.cuda.is_available() else 'cpu'


# Load the implicit CoT model
implicit_cot_model_name = 'yuntian-deng/implicit-cot-math-mistral7b'
implicit_cot_model = AutoModelForCausalLM.from_pretrained(
    implicit_cot_model_name,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)
implicit_cot_model.to(_current_device())
implicit_cot_model.eval()

# Constants
# Upper bound on the number of tokens generated for the answer.
MAX_RESULT_TOKENS = 10


@spaces.GPU
def predict_answer(question):
    """Generate the model's answer for a math word problem.

    Args:
        question: The question text typed into the UI.

    Returns:
        The decoded text of the newly generated tokens (special tokens
        removed). If anything raises while tokenizing or generating, the
        exception's message is returned instead so it appears in the
        output box.
    """
    try:
        # Collapse every whitespace run to a single space, then append the
        # EOS token after a separating space.
        # NOTE(review): presumably the checkpoint expects EOS as the
        # question/answer delimiter -- confirm against the training code.
        input_text = ' '.join(question.split()) + ' ' + tokenizer.eos_token
        print(input_text)

        device = _current_device()
        inputs = tokenizer(input_text, return_tensors='pt').to(device)
        # Move the model again at call time: the device chosen here may
        # differ from the one picked when the module was imported.
        implicit_cot_model.to(device)
        input_ids = inputs['input_ids']

        # Greedy decoding, capped at MAX_RESULT_TOKENS new tokens.
        outputs = implicit_cot_model.generate(
            input_ids=input_ids,
            max_new_tokens=MAX_RESULT_TOKENS,
            do_sample=False,
        )

        # Drop the prompt tokens; decode only what was generated.
        generated_ids = outputs[0, input_ids.shape[-1]:]
        prediction = tokenizer.decode(generated_ids, skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the error text in the output box rather than
        # raising out of the handler.
        prediction = f'{e}'
    return prediction


demo = gr.Interface(
    fn=predict_answer,
    inputs=[
        gr.Textbox(
            label='Question',
            value=(
                "Asumi's bookshelf has 120 books. She has 10 books on history, "
                "twice that many books on literature, and the rest are science "
                "books. How many science books does Asumi have?"
            ),
        ),
    ],
    outputs=[
        gr.Textbox(label='Implicit CoT Prediction'),
    ],
    title='Solving Grade School Math Problems without Intermediate Reasoning Steps',
    description=(
        "This demo showcases Mistral-7B's ability to solve grade school math "
        "problems without producing intermediate steps, using our stepwise "
        "internalization approach linked below."
    ),
    article="""
- [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)
- [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
""",
    clear_btn=None,
    submit_btn="Get Answer!",
    live=False,
    concurrency_limit=1,
)

# Keep at most 5 pending requests in the queue, then start the server.
demo.queue(max_size=5).launch()