Here we can see the different settings for TGI (Text Generation Inference). Read through them and decide which settings matter most for your use case.

Here are some of the most important ones for our purposes:
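- `--model-id` The model to serve: a Hub model ID or a local directory
- `--quantize` Most of the time you want to quantize to save GPU memory
- `--max-input-tokens` The longest prompt a request may send; we will use a high number for RAG and a low number for chat
- `--max-total-tokens` The per-request budget: prompt tokens plus generated tokens must fit within it
- `--max-batch-size` The maximum number of requests processed together in one batch
- `--max-batch-total-tokens` The token budget for a whole batch, which is what ties batching to the available GPU memory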
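The token flags are easier to reason about with concrete numbers. A minimal sketch, assuming the relationships we expect TGI to enforce: prompt plus generated tokens must fit in `--max-total-tokens`, the prompt must leave room to generate (so `--max-input-tokens` is smaller than `--max-total-tokens`), and one full-length request must fit in `--max-batch-total-tokens`. The helper names are hypothetical, not part of TGI, and the launcher's own validation is the final word. The numbers are the ones used in the launch command further down.

```python
def max_new_tokens_budget(prompt_tokens: int, max_total_tokens: int) -> int:
    """Generated tokens left for one request once its prompt is counted."""
    return max(max_total_tokens - prompt_tokens, 0)


def check_token_limits(max_input_tokens: int, max_total_tokens: int,
                       max_batch_total_tokens: int) -> list:
    """Hypothetical pre-flight check of the orderings assumed above."""
    problems = []
    if max_input_tokens >= max_total_tokens:
        problems.append("--max-input-tokens should be smaller than --max-total-tokens")
    if max_batch_total_tokens < max_total_tokens:
        problems.append("--max-batch-total-tokens should be at least --max-total-tokens")
    return problems


def worst_case_batch_tokens(max_batch_size: int, max_total_tokens: int) -> int:
    """Tokens needed if every request in a full batch uses its whole budget."""
    return max_batch_size * max_total_tokens


print(max_new_tokens_budget(3000, 3300))        # 300
print(check_token_limits(3000, 3300, 100_000))  # []
print(worst_case_batch_tokens(5, 3300))         # 16500
```

With 3000 input tokens and 3300 total, a maximal prompt leaves 300 tokens for the answer, and five such requests need at most 16,500 tokens together under this worst-case assumption.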

These two are only set because we are running on Hugging Face Spaces and don't want to conflict with the Spaces server:
- `--hostname`
- `--port`
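Since another server may already own a port on Spaces, it can help to check that the port we hand to `--port` is free before launching. A minimal sketch using only the standard `socket` module; `port_is_free` is a hypothetical helper, binding to `0.0.0.0` mirrors `--hostname 0.0.0.0` (listen on all interfaces), and the answer can go stale if another process takes the port between the check and the launch.

```python
import socket


def port_is_free(port: int, host: str = "0.0.0.0") -> bool:
    """Return True if we can bind host:port, i.e. nothing else holds it."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            # Typically EADDRINUSE: another server is already bound there
            return False
        return True


# Usage: check the port we plan to give TGI before launching it
print(port_is_free(1337))
```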

In [2]:
!text-generation-launcher -h

Text Generation Launcher

[1m[4mUsage:[0m [1mtext-generation-launcher[0m [OPTIONS]

[1m[4mOptions:[0m
      [1m--model-id[0m <MODEL_ID>
          The name of the model to load. Can be a MODEL_ID as listed on <https://hf.co/models> like `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`. Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of transformers [env: MODEL_ID=] [default: bigscience/bloom-560m]
      [1m--revision[0m <REVISION>
          The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id or a branch like `refs/pr/2` [env: REVISION=]
      [1m--validation-workers[0m <VALIDATION_WORKERS>
          The number of tokenizer workers used for payload validation and truncation inside the router [env: VALIDATION_WORKERS=] [default: 2]
      [1m--sharded[0m <SHARDED>
          Whether to shard the model across multiple GPUs By default text-generation-inference will
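The help output shows that options also have an environment-variable form (`[env: MODEL_ID=]`, `[env: REVISION=]`, `[env: VALIDATION_WORKERS=]`). Going by those three examples, the variable name appears to be the flag name upper-cased with dashes turned into underscores. A minimal sketch under that assumption; `flag_to_env` and `settings_to_env` are hypothetical helpers, and for any flag not visible in the truncated output above it is worth confirming the name in the full `-h` output.

```python
import os


def flag_to_env(flag: str) -> str:
    """Assumed mapping, e.g. '--validation-workers' -> 'VALIDATION_WORKERS'."""
    return flag.lstrip("-").replace("-", "_").upper()


def settings_to_env(settings: dict) -> dict:
    """Turn {flag: value} pairs into environment variables for the launcher."""
    return {flag_to_env(flag): str(value) for flag, value in settings.items()}


# Usage: settings passed as environment variables instead of flags
env = {
    **os.environ,
    **settings_to_env({
        "--model-id": "astronomer/Llama-3-8B-Instruct-GPTQ-8-Bit",
        "--validation-workers": 2,
    }),
}
# subprocess.Popen(["text-generation-launcher"], env=env) would then see them
```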

We can launch directly from the notebook since we don't need the command to be interactive.

In [2]:
!RUST_BACKTRACE=1 \
text-generation-launcher \
--model-id astronomer/Llama-3-8B-Instruct-GPTQ-8-Bit \
--quantize gptq \
--max-input-tokens 3000 \
--max-total-tokens 3300 \
--max-batch-size 5 \
--max-batch-total-tokens 100000 \
--hostname 0.0.0.0 \
--port 1337 # We need to change it from the default to play well with Spaces



zsh:1: command not found: text-generation-launcher
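Long launch commands make it easy to pass the same flag twice by accident. One way to guard against that is to keep the settings in a dict and build the argument list from it. A minimal sketch with a hypothetical `build_launcher_command` helper and the settings from the command above (with 100000 as the `--max-batch-total-tokens` budget). A duplicated key in a Python dict literal does not raise an error (the last value silently wins), so printing the final command is still a useful check. The `command not found` output above means the launcher is not on `PATH` in that environment, which `shutil.which` can detect before launching.

```python
import shlex
import shutil


def build_launcher_command(settings: dict) -> list:
    """Turn {flag: value} into an argv list; dict keys are unique, so no flag repeats."""
    command = ["text-generation-launcher"]
    for flag, value in settings.items():
        command += [flag, str(value)]
    return command


settings = {
    "--model-id": "astronomer/Llama-3-8B-Instruct-GPTQ-8-Bit",
    "--quantize": "gptq",
    "--max-input-tokens": 3000,
    "--max-total-tokens": 3300,
    "--max-batch-size": 5,
    "--max-batch-total-tokens": 100_000,
    "--hostname": "0.0.0.0",
    "--port": 1337,
}
command = build_launcher_command(settings)

# Detect the "command not found" case before trying to launch
if shutil.which(command[0]) is None:
    print("text-generation-launcher is not installed or not on PATH")
print(shlex.join(command))
```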


In [None]:
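import subprocess
import time

def stop_server(process, timeout=60):
    # Wait for the launcher to shut down so the GPU memory and port 1337 are
    # released before the next attempt; force-kill if it takes too long
    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()

def launch_server(tokens):
    try:
        command = [
            'text-generation-launcher',
            '--model-id', 'astronomer/Llama-3-8B-Instruct-GPTQ-8-Bit',
            '--quantize', 'gptq',
            '--max-input-tokens', '3000',
            '--max-total-tokens', '3300',  # same per-request budget as the command above
            '--max-batch-size', '5',
            '--max-batch-total-tokens', str(tokens),
            '--hostname', '0.0.0.0',
            '--port', '1337',
        ]

        # Launch without shell=True for safety; stderr is merged into stdout so an
        # unread stderr pipe cannot fill up and stall the launcher
        process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

        # Time limit for the server to either become ready or fail
        time_limit = 120
        start_time = time.time()

        # readline() blocks, so the limit is only checked each time the launcher prints a line
        while time.time() - start_time < time_limit:
            output = process.stdout.readline()
            if output == "" and process.poll() is not None:
                # The launcher exited (e.g. rejected arguments) before reporting success
                print(f"Launcher exited early for {tokens} tokens.")
                return False
            if "Connected" in output:
                print(f"Success message found for {tokens} tokens.")
                stop_server(process)
                return True
            if "RuntimeError" in output or "OutOfMemory" in output:
                print(f"Failure message found for {tokens} tokens.")
                stop_server(process)
                return False

        # No specific message was found, but the process is still running
        if process.poll() is None:
            print(f"No specific message but process is still running for {tokens} tokens.")
            stop_server(process)
            return True
        return False
    except Exception as e:
        print(f"Error launching server with {tokens} tokens: {e}")
        return False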
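`launch_server` reads with `readline()`, which blocks, so its 120-second limit is only checked when the launcher prints something. One way to enforce a hard limit is to do the reads on a background thread and wait on a queue with a timeout. A minimal sketch; `wait_for_markers` is a hypothetical helper, the marker strings are the ones `launch_server` already looks for, and a short Python child process stands in for the launcher.

```python
import queue
import subprocess
import sys
import threading
import time


def wait_for_markers(process, success_markers, failure_markers, timeout):
    """Watch a child's output for marker strings with a hard time limit.

    Returns True on a success marker, False on a failure marker or if the
    output ends first, and None if the timeout expires. A background thread
    does the blocking reads, so a silent process cannot stall the check.
    """
    lines = queue.Queue()

    def pump():
        for line in process.stdout:
            lines.put(line)
        lines.put(None)  # signals end of output

    threading.Thread(target=pump, daemon=True).start()
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return None
        try:
            line = lines.get(timeout=remaining)
        except queue.Empty:
            return None
        if line is None:
            return False
        if any(marker in line for marker in success_markers):
            return True
        if any(marker in line for marker in failure_markers):
            return False


# Usage: a Python child stands in for text-generation-launcher here
child = subprocess.Popen([sys.executable, "-c", "print('Connected')"],
                         stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
print(wait_for_markers(child, ["Connected"], ["RuntimeError", "OutOfMemory"], timeout=10))
child.wait()
```

Returning `None` for a timeout keeps "still starting" distinct from a failure, so the caller can decide, as `launch_server` does, whether a server that is still running counts as a success.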

In [None]:
%%time
def find_max_batch_total_tokens():
    low, high = 0, 1_000_000  # Adjust the upper bound as needed
    best_valid = 0

    while low <= high:
        mid = (low + high) // 2
        print(f"Testing with {mid} max-batch-total-tokens...")

        if launch_server(mid):
            print(f"Success with {mid} max-batch-total-tokens.")
            best_valid = mid  # Update best known valid count
            low = mid + 1
        else:
            print(f"Failed with {mid} max-batch-total-tokens.")
            high = mid - 1

    print(f"Maximum manageable max-batch-total-tokens: {best_valid}")
    return best_valid

# Call the function
max_tokens = find_max_batch_total_tokens()
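The search only works because we assume validity is monotone: if a budget fits in GPU memory, every smaller budget fits too. Each step at least halves the remaining range, so `[0, 1_000_000]` needs at most 20 launches; with the 120-second limit per launch, that is up to roughly 40 minutes from the time limit alone. A minimal sketch that separates the search from the server so it can be checked without a GPU; `find_max_valid` and `fake_server` are hypothetical names, and the stand-in predicate replaces `launch_server`. Raising `low` to the per-request `--max-total-tokens` (3300 here) would also skip values we assume the launcher rejects anyway.

```python
def find_max_valid(is_valid, low: int, high: int):
    """Largest n in [low, high] with is_valid(n) True, assuming validity is
    monotone (valid up to some threshold, invalid above it). None if no n is valid."""
    best = None
    while low <= high:
        mid = (low + high) // 2
        if is_valid(mid):
            best = mid
            low = mid + 1   # a larger budget might still work
        else:
            high = mid - 1  # by monotonicity, everything above mid fails too
    return best


# Usage with a stand-in predicate instead of launching TGI
probes = []

def fake_server(tokens):
    probes.append(tokens)
    return tokens <= 42_000

print(find_max_valid(fake_server, 0, 1_000_000))  # 42000
```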