This post documents the workflow for training a Flux Kontext face-swap LoRA with musubi-tuner on an RTX 4090. Unlike the earlier Flux LoRA training, a Flux Kontext LoRA needs an extra reference image per sample, and the training framework switches from sd-scripts to musubi-tuner.
Tool Configuration
Since musubi-tuner comes from the same author, its installation and configuration are similar to sd-scripts. Use the following commands:
git clone https://github.com/kohya-ss/musubi-tuner.git
cd musubi-tuner
python3.10 -m venv venv
source venv/bin/activate
pip install -i https://mirrors.aliyun.com/pypi/simple --extra-index-url https://download.pytorch.org/whl/cu124 --default-timeout=100 -r requirements.txt
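Optionally, before configuring accelerate you can verify that the PyTorch build inside the venv sees the GPU and supports bf16. This is just a quick sanity-check sketch, not part of the official setup:
import torch

# expect a cu124 build (per the extra index URL above), CUDA available,
# the device name mentioning "RTX 4090", and bf16 support on Ada GPUs
print(torch.__version__, torch.version.cuda)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))
print(torch.cuda.is_bf16_supported())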
Configure accelerate:
accelerate config
Enable NUMA efficiency and select bf16. The generated ~/.cache/huggingface/accelerate/default_config.yaml contains:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
enable_cpu_affinity: true
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Dataset Preparation
dataset.toml is similar to the one used for Flux LoRA, but because Flux Kontext needs an extra reference image, the labels are supplied through metadata.jsonl.
[general]
resolution = [1024, 1024]
batch_size = 1
enable_bucket = true
bucket_no_upscale = false
[[datasets]]
image_jsonl_file = "/data/musubi-tuner/data/kontext_face_restoration/dataset_base/metadata.jsonl"
cache_directory = "/data/musubi-tuner/data/kontext_face_restoration/cache/dataset_base"
num_repeats = 2
flux_kontext_no_resize_control = false
[[datasets]]
image_jsonl_file = "/data/musubi-tuner/data/kontext_face_restoration/dataset_style/metadata.jsonl"
cache_directory = "/data/musubi-tuner/data/kontext_face_restoration/cache/dataset_style"
num_repeats = 2
flux_kontext_no_resize_control = false
metadata.jsonl looks like this:
{"image_path": "/data/musubi-tuner/data/kontext_face_restoration/dataset_base/14_end.png", "control_path": "/data/musubi-tuner/data/kontext_face_restoration/dataset_base/14_start.png", "caption": "face restoration"}
Here image_path is the target image, i.e. the image after the face swap; control_path is the reference image, i.e. the input to the face-swap workflow; and caption is the image description.
For our face-swap model, control_path is the image with face A overlaid on person B, while image_path is the natural-looking image after face A has been swapped in.
To improve the model's stability and generalization, you can add face-swap pairs across different styles, or pairs where the face orientation differs slightly.
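If the image pairs follow the N_start.png / N_end.png naming used in the example above, metadata.jsonl can be generated with a small script. The sketch below is a hypothetical helper, not part of musubi-tuner; adjust the glob pattern, paths, and caption to your own data:
# build_metadata.py -- minimal sketch that pairs *_end.png (target) with *_start.png (control)
import json
from pathlib import Path

dataset_dir = Path("/data/musubi-tuner/data/kontext_face_restoration/dataset_base")
caption = "face restoration"

with open(dataset_dir / "metadata.jsonl", "w", encoding="utf-8") as f:
    for target in sorted(dataset_dir.glob("*_end.png")):
        control = target.with_name(target.name.replace("_end", "_start"))
        if not control.exists():
            print(f"skipping {target.name}: no matching *_start.png")
            continue
        record = {
            "image_path": str(target),     # target image (after the face swap)
            "control_path": str(control),  # reference/control image (workflow input)
            "caption": caption,
        }
        f.write(json.dumps(record, ensure_ascii=False) + "\n")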
Training
Compared with sd-scripts, musubi-tuner adds an explicit data preprocessing stage: cache the latents and the text-encoder outputs first, then launch training.
python src/musubi_tuner/flux_kontext_cache_latents.py --vae /data/models/FLUX.1-Kontext-dev/ae.safetensors --dataset_config /data/musubi-tuner/data/kontext_face_restoration/dataset.toml
python src/musubi_tuner/flux_kontext_cache_text_encoder_outputs.py --text_encoder1 /data/models/flux_text_encoders/t5xxl_fp16.safetensors --text_encoder2 /data/models/flux_text_encoders/clip_l.safetensors --batch_size 2 --dataset_config /data/musubi-tuner/data/kontext_face_restoration/dataset.toml
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 4 src/musubi_tuner/flux_kontext_train_network.py --config_file /data/musubi-tuner/data/kontext_face_restoration/config.toml
The config.toml used here is as follows:
dit = "/data/models/FLUX.1-Kontext-dev/flux1-kontext-dev.safetensors"
vae = "/data/models/FLUX.1-Kontext-dev/ae.safetensors"
text_encoder1 = "/data/models/flux_text_encoders/t5xxl_fp16.safetensors"
text_encoder2 = "/data/models/flux_text_encoders/clip_l.safetensors"
dataset_config = "/data/musubi-tuner/data/kontext_face_restoration/dataset.toml"
output_dir = "/data/musubi-tuner/data/kontext_face_restoration/output"
output_name = "face_restoration"
save_model_as = "safetensors"
sdpa = true
mixed_precision = "bf16"
save_precision = "bf16"
timestep_sampling = "flux_shift"
weighting_scheme = "none"
optimizer_type = "adamw8bit"
learning_rate = 0.0002
gradient_checkpointing = true
max_data_loader_n_workers = 2
persistent_data_loader_workers = true
network_module = "networks.lora_flux"
network_dim = 64
network_alpha = 32
max_train_epochs = 100
save_every_n_epochs = 20
seed = 9527
fp8_base = true
fp8_scaled = true
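As a rough planning aid, the number of optimizer steps follows from the dataset sizes together with num_repeats, batch_size, and max_train_epochs above. The sketch below uses hypothetical image counts; actual numbers depend on your data, and bucketing can regroup batches when batch_size > 1:
import math

n_base, n_style = 20, 20            # hypothetical image counts for the two datasets
num_repeats, batch_size = 2, 1      # from dataset.toml
max_train_epochs = 100              # from config.toml

steps_per_epoch = math.ceil((n_base + n_style) * num_repeats / batch_size)
total_steps = steps_per_epoch * max_train_epochs
print(steps_per_epoch, total_steps)  # e.g. 80 steps per epoch, 8000 steps in total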