# flux_trainer / app.py
# Author: OmPrakashSingh1704 — commit cff43cf ("Update app.py", verified).
# Original page metadata: 3.84 kB, virus scan clean.
import subprocess
import os
import platform
import sys
def run_command(command, cwd=None, shell=False, capture=True):
    """Run an external command, echo its output, and abort the script on failure.

    Args:
        command: Argument list (or a single string when ``shell=True``)
            forwarded to ``subprocess.run``.
        cwd: Working directory for the command; ``None`` keeps the current one.
        shell: Forwarded to ``subprocess.run``; run via the system shell.
        capture: When True (the default, matching earlier behaviour) stdout
            and stderr are collected and stdout is printed once the command
            finishes. Pass False for long-running commands so their output
            goes straight to this process's terminal as it is produced.

    Returns:
        The ``subprocess.CompletedProcess`` of a successful run.

    Exits:
        Calls ``sys.exit(1)`` if the command returns a non-zero status or
        cannot be started at all (e.g. the executable is not installed or
        ``cwd`` does not exist).
    """
    stream = subprocess.PIPE if capture else None
    try:
        result = subprocess.run(
            command,
            check=True,
            stdout=stream,
            stderr=stream,
            cwd=cwd,
            text=True,
            shell=shell,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while running command: {command}")
        # Print both captured streams; the explanation of a failure may be on
        # stdout, which was previously discarded. (Both are None when
        # capture=False, because the output already reached the terminal.)
        for captured in (e.stdout, e.stderr):
            if captured:
                print(captured)
        sys.exit(1)
    except OSError as e:
        # Raised when the program cannot be launched at all; report it the
        # same way as a failing command instead of dying with a traceback.
        print(f"Error occurred while running command: {command}")
        print(e)
        sys.exit(1)
    if result.stdout is not None:
        print(result.stdout)
    return result
# Step 1: Define the repository URL and where it is cloned.
repo_url = 'https://github.com/ostris/ai-toolkit.git'
clone_path = './ai-toolkit'
# Absolute path of the checkout. The process's working directory is never
# changed; every command that must run inside the checkout gets ``cwd=``.
repo_dir = os.path.abspath(clone_path)
flux_train_ui_path = os.path.join(repo_dir, 'flux_train_ui.py')

# Lines appended to flux_train_ui.py. Kept as a single constant so the patch
# step can tell whether they are already present and append them only once.
_CUDA_CHECK_SNIPPET = (
    '\nimport torch\n'
    'if not torch.cuda.is_available():\n'
    ' print("CUDA is not available, running on CPU.")\n'
    ' # Adjust logic to run on CPU if necessary\n'
)


def _patch_flux_train_ui(path):
    """Adjust the flux_train_ui.py file at *path* for running inside this app.

    * On every line that calls ``launch(`` with ``share=True``, the flag is
      rewritten to ``share=False`` so Gradio does not try to create a share
      link. Matching ``launch(`` and ``share=True`` separately (rather than the
      literal ``launch(share=True)``) also catches calls that pass further
      keyword arguments, e.g. ``launch(share=True, show_error=True)``.
    * Appends a CUDA-availability message, unless the file already contains
      it. Reused checkouts (the "directory already exists" path) would
      otherwise accumulate one extra copy per run.
    """
    with open(path, 'r', encoding='utf-8') as file:
        lines = file.readlines()
    patched = [
        line.replace('share=True', 'share=False')
        if 'launch(' in line and 'share=True' in line
        else line
        for line in lines
    ]
    text = ''.join(patched)
    # NOTE(review): the snippet is placed at the very end of the file, so it
    # only executes after everything above it — including the UI launch, if
    # that call blocks. Move it earlier if the message should appear at start.
    if _CUDA_CHECK_SNIPPET not in text:
        text += _CUDA_CHECK_SNIPPET
    with open(path, 'w', encoding='utf-8') as file:
        file.write(text)


def _login_to_hub():
    """Log in to the Hugging Face Hub with the token in the TOKEN env var."""
    # Imported lazily, after the dependency-install step has had a chance to run.
    from huggingface_hub import login
    login(token=os.getenv("TOKEN"))


def _launch_ui():
    """Run flux_train_ui.py from the checkout, streaming its output live.

    Deliberately not routed through ``run_command``: that helper pipes the
    child's output and prints it only after the command exits. The UI script
    (which launches Gradio) is expected to keep running, so its logs would
    stay hidden. A failure is reported in the same format as ``run_command``
    and ends this script with exit status 1.
    """
    command = [sys.executable, 'flux_train_ui.py']
    try:
        subprocess.run(command, cwd=repo_dir, check=True)
    except subprocess.CalledProcessError:
        # The child's stderr has already been written straight to the terminal.
        print(f"Error occurred while running command: {command}")
        sys.exit(1)


# Fresh checkout: clone the repository and install its dependencies.
# NOTE(review): any existing directory is treated as fully set up. A previous
# run interrupted after cloning (e.g. during a pip install) also takes the
# "already exists" path and skips installation.
if not os.path.exists(clone_path):
    run_command(['git', 'clone', repo_url, clone_path])
    # Step 2: Update submodules
    run_command(
        ['git', 'submodule', 'update', '--init', '--recursive'],
        cwd=repo_dir,
    )
    # Step 3: Install torch and torchvision from the cu121 wheel index
    run_command(
        [sys.executable, '-m', 'pip', 'install', 'torch', 'torchvision',
         '--index-url', 'https://download.pytorch.org/whl/cu121'],
        cwd=repo_dir,
    )
    # Step 4: Install dependencies from requirements.txt
    run_command(
        [sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'],
        cwd=repo_dir,
    )
    # Step 5: Install mediapipe (to handle the missing dependency)
    run_command(
        [sys.executable, '-m', 'pip', 'install', 'mediapipe'],
        cwd=repo_dir,
    )
    print("Setup completed successfully.")
else:
    print("The 'ai-toolkit' directory already exists. No need to clone.")

# Step 6: Patch flux_train_ui.py, authenticate with the Hub, and start the UI.
# These steps run the same way whether or not the checkout was just created.
_patch_flux_train_ui(flux_train_ui_path)
_login_to_hub()
_launch_ui()