import os
import shutil

import gradio as gr
import requests
import torch
from safetensors.torch import save_file, save_model

# Seconds to wait for the server to connect / send data before giving up.
DOWNLOAD_TIMEOUT_SECONDS = 30
# Size of each streamed download chunk written to disk.
DOWNLOAD_CHUNK_BYTES = 1 << 20


def _unwrap_checkpoint(obj):
    """Return the nested 'state_dict' or 'model' entry of a dict, else obj."""
    if isinstance(obj, dict):
        for key in ('state_dict', 'model'):
            if key in obj:
                return obj[key]
    return obj


def convert_ckpt_to_safetensors(input_path, output_path):
    """Convert a PyTorch checkpoint file to the safetensors format.

    Args:
        input_path: Path of the .ckpt file to read.
        output_path: Path of the .safetensors file to write.

    Returns:
        "Success" when the file was written, otherwise a human-readable
        error message.

    Raises:
        Whatever ``torch.load`` raises if the checkpoint cannot be read.
    """
    # SECURITY WARNING:
    # Loading untrusted .ckpt files with torch.load() can execute arbitrary
    # code. Only load files from trusted sources.
    obj = torch.load(input_path, map_location='cpu')

    weights = _unwrap_checkpoint(obj)

    try:
        if isinstance(weights, torch.nn.Module):
            # save_model takes a module (not a dict) and handles shared
            # tensors such as a tied LM head.
            save_model(weights, output_path)
            return "Success"

        if not isinstance(weights, dict) and hasattr(weights, 'state_dict'):
            weights = weights.state_dict()
        if not isinstance(weights, dict):
            return "Unsupported checkpoint format."

        # save_file rejects non-tensor values, non-contiguous tensors and
        # tensors sharing storage, so keep only tensors and give each its
        # own contiguous copy.
        tensors = {
            str(name): value.detach().clone(
                memory_format=torch.contiguous_format
            )
            for name, value in weights.items()
            if isinstance(value, torch.Tensor)
        }
        if not tensors:
            return "No tensors found in checkpoint."
        save_file(tensors, output_path)
    except Exception as e:
        return f"An error occurred during saving: {e}"

    return "Success"


def _download(url, destination):
    """Stream the content at url into the file at destination."""
    with requests.get(
        url, stream=True, timeout=DOWNLOAD_TIMEOUT_SECONDS
    ) as response:
        response.raise_for_status()
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(
                chunk_size=DOWNLOAD_CHUNK_BYTES
            ):
                f.write(chunk)


def _save_upload(uploaded_file, destination):
    """Copy an uploaded file to destination.

    Accepts a filesystem path, an object exposing the path as ``.name``,
    or a readable binary file-like object.
    """
    if isinstance(uploaded_file, (str, os.PathLike)):
        shutil.copyfile(uploaded_file, destination)
    elif getattr(uploaded_file, 'name', None):
        shutil.copyfile(uploaded_file.name, destination)
    else:
        with open(destination, 'wb') as f:
            shutil.copyfileobj(uploaded_file, f)


def process(url, uploaded_file):
    """Fetch a .ckpt file (URL takes priority over upload) and convert it.

    Args:
        url: URL of the checkpoint to download; may be empty.
        uploaded_file: File supplied through the upload widget, or None.

    Returns:
        Path of the converted .safetensors file.

    Raises:
        gr.Error: If no input was given, or if downloading, saving or
            converting the file fails.
    """
    url = (url or "").strip()
    if url:
        local_filename = 'model.ckpt'
    elif uploaded_file is not None:
        local_filename = 'uploaded_model.ckpt'
    else:
        raise gr.Error("Please provide a URL or upload a .ckpt file.")

    output_filename = local_filename.replace('.ckpt', '.safetensors')

    try:
        if url:
            try:
                _download(url, local_filename)
            except Exception as e:
                raise gr.Error(f"Failed to download file: {e}") from e
        else:
            try:
                _save_upload(uploaded_file, local_filename)
            except Exception as e:
                raise gr.Error(f"Failed to save uploaded file: {e}") from e

        # Convert the .ckpt to .safetensors
        try:
            result = convert_ckpt_to_safetensors(
                local_filename, output_filename
            )
        except Exception as e:
            raise gr.Error(f"An exception occurred: {e}") from e

        if result != "Success":
            raise gr.Error(f"An error occurred during conversion: {result}")
    finally:
        # Clean up the input file on every path, including a partially
        # written download.
        if os.path.exists(local_filename):
            os.remove(local_filename)

    # The File output component serves the file at the returned path.
    return output_filename


iface = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="URL of .ckpt file", placeholder="Enter the URL here"),
        gr.File(label="Or upload a .ckpt file", file_types=['.ckpt']),
    ],
    outputs=gr.File(label="Converted .safetensors file"),
    title="CKPT to SafeTensors Converter",
    description="""
    Convert .ckpt files to .safetensors format. Provide a URL or upload your .ckpt file.

    **Security Warning:** Loading .ckpt files can execute arbitrary code. Only use files from trusted sources.
    """
)

iface.launch()