v5e Cloud TPU VM에서 JetStream MaxText 추론


JetStream은 XLA 기기(TPU)에서 대규모 언어 모델(LLM) 추론을 위한 처리량 및 메모리 최적화 엔진입니다.

시작하기 전에

TPU 리소스 관리의 단계에 따라 --accelerator-type에서 v5litepod-8로 설정하는 TPU VM을 만들고 TPU VM에 연결합니다.

JetStream 및 MaxText 설정

  1. JetStream 및 MaxText GitHub 저장소 다운로드

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. MaxText 설정

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

모델 체크포인트 변환

JetStream MaxText 서버를 Gemma 또는 Llama2 모델과 함께 실행할 수 있습니다. 이 섹션에서는 이러한 모델의 다양한 크기로 JetStream MaxText 서버를 실행하는 방법을 설명합니다.

Gemma 모델 체크포인트 사용

  1. Kaggle에서 Gemma 체크포인트를 다운로드합니다.
  2. 체크포인트를 Cloud Storage 버킷에 복사합니다.

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    ${YOUR_CKPT_PATH}${CHKPT_BUCKET}의 값이 포함된 예시는 변환 스크립트를 참조하세요.

  3. Gemma 체크포인트를 MaxText와 호환되는 스캔되지 않은 체크포인트로 변환합니다.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Llama2 모델 체크포인트 사용

  1. 오픈소스 커뮤니티에서 Llama2 체크포인트를 다운로드하거나 사용자가 만든 체크포인트를 사용하세요.

  2. 체크포인트를 Cloud Storage 버킷에 복사합니다.

       gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    ${YOUR_CKPT_PATH}${CHKPT_BUCKET}의 값이 포함된 예시는 변환 스크립트를 참조하세요.

  3. Llama2 체크포인트를 MaxText와 호환되는 스캔되지 않은 체크포인트로 변환합니다.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

JetStream MaxText 서버 실행

이 섹션에서는 MaxText 호환 체크포인트를 사용하여 MaxText 서버를 실행하는 방법을 설명합니다.

MaxText 서버용 환경 변수 구성

사용 중인 모델을 기준으로 다음 환경 변수를 내보냅니다. model_ckpt_conversion.sh의 출력에서 UNSCANNED_CKPT_PATH에 이 값을 사용합니다.

서버 플래그용 Gemma-7b 환경 변수 만들기

JetStream MaxText 서버 플래그를 구성합니다.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

서버 플래그용 Llama2-7b 환경 변수 만들기

JetStream MaxText 서버 플래그를 구성합니다.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

서버 플래그용 Llama2-13b 환경 변수 만들기

JetStream MaxText 서버 플래그를 구성합니다.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

JetStream MaxText 서버 시작

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

JetStream MaxText Server 플래그 설명

tokenizer_path
tokenizer의 경로입니다(모델과 일치해야 함).
load_parameters_path
특정 디렉터리에서 매개변수(최적화 도구 상태 없음)를 로드합니다.
per_device_batch_size
기기당 디코딩 배치 크기(TPU 칩 1개 = 기기 1개)
max_prefill_predict_length
자동 회귀 수행 시 미리 입력의 최대 길이
max_target_length
최대 시퀀스 길이
model_name
모델 이름
ici_fsdp_parallelism
FSDP 동시 로드의 샤드 수
ici_autoregressive_parallelism
: 자동 회귀 동시 로드의 샤드 수
ici_tensor_parallelism
텐서 동시 로드의 샤드 수
weight_dtype
가중치 데이터 유형(예: bfloat16)
scan_layers
스캔 레이어 불리언 플래그(추론을 위해 `false` 로 설정)

JetStream MaxText 서버에 테스트 요청 전송

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

출력은 다음과 비슷하게 표시됩니다.

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

JetStream MaxText 서버로 벤치마크 실행

최상의 벤치마크 결과를 ��으려면 가중치와 KV 캐시 모두에 대해 양자화(정확도 보장을 위해 AQT 훈련 또는 미세 조정된 체크포인트 사용)를 사용 설정합니다. 양자화를 사용 설정하려면 양자화 플래그를 설정합니다.

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Gemma-7b 벤치마킹

Gemma-7b를 벤치마킹하려면 다음 단계를 따르세요.

  1. ShareGPT 데이터 세트를 다운로드합니다.
  2. Gemma 7b를 실행할 때는 Gemma tokenizer(tokenizer.gemma)를 사용해야 합니다.
  3. 첫 번째 실행에서 --warmup-first 플래그를 추가하여 서버를 워밍업합니다.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

더 큰 Llama2 벤치마킹

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

삭제

이 가이드에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 프로젝트를 유지하고 개별 리소스를 삭제하세요.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env