# Hugging Face Spaces status: Runtime error
import contextlib
import io
import sys

import streamlit as st
import torch
def execute_code(code):
    """Run a snippet of Python source and capture what it prints.

    Args:
        code: Python source text to execute. ``torch`` is pre-bound in the
            execution namespace, so the snippet does not need to import it.

    Returns:
        A ``(output, variables)`` tuple. ``output`` is everything the
        snippet wrote to stdout; if the snippet raised, the error is
        appended as ``"<ExceptionType>: <message>"`` after whatever was
        printed before the failure. ``variables`` maps each name left in
        the snippet's namespace to its value, omitting ``__builtins__`` and
        the pre-bound ``torch`` entry when the snippet left it unchanged.
    """
    # SECURITY: exec() runs arbitrary code with this process's privileges.
    # This function provides no isolation; if the app is reachable by
    # untrusted users, snippets must be sandboxed (separate process or
    # container with resource limits).
    preset = {"torch": torch}
    # A single dict serves as both globals and locals. With separate dicts,
    # exec() runs the code as if embedded in a class definition, so
    # functions defined in the snippet could not see its top-level names.
    namespace = dict(preset)
    buffer = io.StringIO()
    try:
        # redirect_stdout restores sys.stdout even if the snippet raises.
        # NOTE(review): sys.stdout is process-wide, so concurrent runs in
        # other threads/sessions may interleave output — confirm whether
        # that matters for this deployment.
        with contextlib.redirect_stdout(buffer):
            exec(code, namespace)
    except Exception as exc:
        # Keep what was printed before the failure, then report the error.
        output = buffer.getvalue() + f"{type(exc).__name__}: {exc}"
    else:
        output = buffer.getvalue()
    variables = {
        name: value
        for name, value in namespace.items()
        if name != "__builtins__"
        and not (name in preset and preset[name] is value)
    }
    return output, variables
st.title('PyTorch Code Runner')

# Starter snippet shown in the editor: broadcasts a shape-(3,) tensor
# against a shape-(3, 1) tensor for addition and multiplication.
_EXAMPLE_SNIPPET = """# Create two tensors of different shapes
tensor_c = torch.tensor([1, 2, 3])
tensor_d = torch.tensor([[1], [2], [3]])
# Perform addition using broadcasting
tensor_broadcast_add = tensor_c + tensor_d
print("Broadcast Addition:\\n", tensor_broadcast_add)
# Perform element-wise multiplication using broadcasting
tensor_broadcast_mul = tensor_c * tensor_d
print("Broadcast Multiplication:\\n", tensor_broadcast_mul)
"""

# Editable text area holding the code the user wants to run.
code_input = st.text_area(
    "Enter your PyTorch code here",
    height=300,
    value=_EXAMPLE_SNIPPET,
)
# Button to execute the code
if st.button("Run Code"):
    # `torch` is already bound in execute_code's namespace, so the snippet
    # runs unmodified. Prepending an import line would shift every line
    # number in reported errors by one relative to the editor contents.
    output, variables = execute_code(code_input)

    # Show everything the snippet printed (plus any error message).
    st.subheader('Output')
    st.text(output)

    # Show the names the snippet left behind, one per line.
    if variables:
        st.subheader('Variables')
        for key, value in variables.items():
            st.text(f"{key}: {value}")