shango/glm-trainer/glm_trainer.py

137 lines
4.0 KiB
Python

#!/usr/bin/env python3
"""
glm_trainer.py — Fine-tune local de modèles GLM/LLaMA pour les tasks Shango
"""
import argparse, json, os, subprocess, sys
from pathlib import Path
# Baseline settings for a training run. main() copies this mapping and
# overrides "dataset", "output", "model" and "epochs" from the command line;
# the "config" sub-command prints it verbatim as JSON.
DEFAULT_CONFIG = {
    # Model identifier handed to from_pretrained()
    # (original author's note: "or a local GGUF model").
    "model": "THUDM/chatglm3-6b",
    # JSON Lines training file.
    "dataset": "./training_data.jsonl",
    # Directory that receives the generated script and the trained model.
    "output": "./shango-model",
    # Optimisation hyper-parameters.
    "epochs": 3,
    "batch_size": 4,
    "learning_rate": 2e-5,
    "max_seq_length": 512,
    # LoRA adapter settings.
    "lora_r": 16,
    "lora_alpha": 32,
    # Quantization mode; original author's note lists "4bit", "8bit" or "none".
    "quantization": "4bit",
}
def check_deps() -> bool:
    """Report whether the training libraries can be imported.

    Tries to import torch, transformers, peft and datasets, in that order.

    Returns:
        True when all four imports succeed (the PyTorch version and CUDA
        availability are printed); False when any import fails (the error
        and a pip install hint are printed).
    """
    try:
        import torch
        import transformers
        import peft
        import datasets
    except ImportError as missing:
        print(f"[GLM] Manque dépendance: {missing}")
        print("[GLM] Install: pip install torch transformers peft datasets bitsandbytes")
        return False
    else:
        print(f"[GLM] PyTorch {torch.__version__}, CUDA={torch.cuda.is_available()}")
        return True
def generate_training_script(config: dict) -> str:
    """Render a standalone LoRA fine-tuning script for the given config.

    The returned source embeds ``config`` as JSON and, when executed, loads
    the model and tokenizer, wraps the model with a LoRA adapter, tokenizes
    the JSONL dataset ("prompt" + "completion" per record), trains with the
    Hugging Face ``Trainer`` and saves model and tokenizer to
    ``config["output"]``.

    Args:
        config: JSON-serialisable mapping with the keys "model", "dataset",
            "output", "epochs", "batch_size", "learning_rate",
            "max_seq_length", "lora_r", "lora_alpha" and "quantization".

    Returns:
        The Python source code of the training script.

    Raises:
        TypeError: if ``config`` is not JSON-serialisable.
    """
    # repr() of the JSON text is always a valid Python string literal, so
    # quotes or backslashes inside config values (e.g. in file paths) can
    # neither break the generated source nor inject code into it.
    config_literal = repr(json.dumps(config))
    return f'''
import json, torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset

# Config
config = json.loads({config_literal})

# Load model (optionally quantized)
# NOTE: trust_remote_code=True allows the model repository to run its own
# Python code; only use model ids you trust.
model_id = config["model"]
quantization = config["quantization"]
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    # The data collator pads every batch, which requires a pad token.
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    load_in_4bit=quantization == "4bit",
    load_in_8bit=quantization == "8bit",
)
model = prepare_model_for_kbit_training(model)

# LoRA
# NOTE(review): "query_key_value" presumably matches ChatGLM-style attention
# layers; other architectures may need different target_modules - confirm.
lora_config = LoraConfig(
    r=config["lora_r"],
    lora_alpha=config["lora_alpha"],
    target_modules=["query_key_value"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Dataset: one JSON object per line, blank lines are ignored
with open(config["dataset"], encoding="utf-8") as f:
    records = [json.loads(line) for line in f if line.strip()]
dataset = Dataset.from_list(records)

def tokenize(ex):
    prompt = ex.get("prompt") or ""
    completion = ex.get("completion") or ""
    return tokenizer(prompt + completion, truncation=True, max_length=config["max_seq_length"])

# Map one example at a time so that prompt and completion are strings, and
# drop the raw text columns so that only tokenizer outputs reach the collator.
tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)

# Train
training_args = TrainingArguments(
    output_dir=config["output"],
    num_train_epochs=config["epochs"],
    per_device_train_batch_size=config["batch_size"],
    learning_rate=config["learning_rate"],
    logging_steps=10,
    save_strategy="epoch",
    fp16=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    # Pads each batch and derives the causal-LM labels from input_ids;
    # without it the Trainer receives no labels to compute a loss from.
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
trainer.train()
model.save_pretrained(config["output"])
tokenizer.save_pretrained(config["output"])
print(f"[GLM] Modèle sauvé dans {{config['output']}}")
'''
def main():
    """Command-line entry point for ``glm-trainer``.

    Sub-commands:
        check:  report whether the training dependencies can be imported.
        config: print DEFAULT_CONFIG as indented JSON.
        train:  write ``train_script.py`` into the output directory and
                print the command that launches it. Training itself is not
                started here.

    Exits with status 1 when ``train`` is requested but a dependency is
    missing.
    """
    parser = argparse.ArgumentParser(prog="glm-trainer")
    parser.add_argument("command", choices=["train", "config", "check"])
    # Defaults are read from DEFAULT_CONFIG so the two cannot drift apart.
    parser.add_argument("--dataset", default=DEFAULT_CONFIG["dataset"])
    parser.add_argument("--output", default=DEFAULT_CONFIG["output"])
    parser.add_argument("--model", default=DEFAULT_CONFIG["model"])
    parser.add_argument("--epochs", type=int, default=DEFAULT_CONFIG["epochs"])
    args = parser.parse_args()

    if args.command == "check":
        check_deps()
        return
    if args.command == "config":
        print(json.dumps(DEFAULT_CONFIG, indent=2))
        return

    # argparse restricts the choices, so only "train" reaches this point.
    if not check_deps():
        sys.exit(1)

    # The generated script is meant to be run from inside the output
    # directory, so paths given relative to the current directory must be
    # embedded as absolute paths or they would point somewhere else.
    model = args.model
    if os.path.isdir(model):
        # A local model directory; hub identifiers are left untouched.
        model = os.path.abspath(model)
    config = {
        **DEFAULT_CONFIG,
        "dataset": os.path.abspath(args.dataset),
        "output": os.path.abspath(args.output),
        "model": model,
        "epochs": args.epochs,
    }

    script = generate_training_script(config)
    script_path = Path(args.output) / "train_script.py"
    script_path.parent.mkdir(parents=True, exist_ok=True)
    # The script contains non-ASCII text and Python reads source files as
    # UTF-8, so do not rely on the locale's default encoding.
    script_path.write_text(script, encoding="utf-8")
    print(f"[GLM] Script généré: {script_path}")
    print(f"[GLM] Lancer: cd {args.output} && python train_script.py")


if __name__ == "__main__":
    main()