137 lines
4.0 KiB
Python
137 lines
4.0 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
glm_trainer.py — Fine-tuning local (LoRA) de modèles GLM/LLaMA pour les tâches Shango
|
|
"""
|
|
import argparse, json, os, subprocess, sys
|
|
from pathlib import Path
|
|
|
|
# Baseline training configuration. `main` overrides dataset/output/model/epochs
# from the command line; the generated training script reads every key.
DEFAULT_CONFIG = dict(
    model="THUDM/chatglm3-6b",  # Hugging Face Hub id or local model directory (from_pretrained)
    dataset="./training_data.jsonl",  # JSONL file, one JSON example per line
    output="./shango-model",  # receives the generated script and the trained weights
    epochs=3,
    batch_size=4,  # per-device train batch size
    learning_rate=2e-5,
    max_seq_length=512,  # longer tokenized examples are truncated
    lora_r=16,  # LoRA rank
    lora_alpha=32,  # LoRA scaling factor
    quantization="4bit",  # intended values: "4bit", "8bit" or "none"
)
|
|
|
|
def check_deps() -> bool:
    """Return True when torch, transformers, peft and datasets all import.

    On success, prints the PyTorch version and whether CUDA is available.
    On the first ImportError, prints the missing module plus a pip install
    hint and returns False.
    """
    try:
        import torch
        import transformers  # noqa: F401  (imported only to prove it is installed)
        import peft  # noqa: F401
        import datasets  # noqa: F401
    except ImportError as missing:
        print(f"[GLM] Manque dépendance: {missing}")
        print("[GLM] Install: pip install torch transformers peft datasets bitsandbytes")
        return False
    print(f"[GLM] PyTorch {torch.__version__}, CUDA={torch.cuda.is_available()}")
    return True
|
|
|
|
def generate_training_script(config: dict) -> str:
    """Render a standalone LoRA fine-tuning script for *config*.

    When run, the returned source loads ``config["model"]`` with transformers,
    optionally quantizes it with bitsandbytes, wraps it in a PEFT LoRA
    adapter, and tokenizes the JSONL file ``config["dataset"]`` (``prompt`` +
    ``completion`` per line). It then trains with ``transformers.Trainer`` and
    saves the adapter and tokenizer to ``config["output"]``.

    Keys read by the generated script: ``model``, ``dataset``, ``output``,
    ``epochs``, ``batch_size``, ``learning_rate``, ``max_seq_length``,
    ``lora_r``, ``lora_alpha``, ``quantization`` ("4bit", "8bit", anything
    else = no quantization). The optional ``target_modules`` key defaults to
    ``["query_key_value"]``, ChatGLM's fused attention projection.
    LLaMA-family models name their projections differently (e.g.
    ``q_proj``/``v_proj``).

    Args:
        config: JSON-serializable training configuration.

    Returns:
        Python source code of the training script.

    Raises:
        TypeError: if *config* is not JSON-serializable.
    """
    # Plain (non f-string) template: braces need no escaping and the config is
    # injected through a single placeholder below.
    template = '''\
import json

import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Config embedded by glm_trainer.py
config = json.loads(__CONFIG_JSON__)

# Tokenizer: padding causal-LM batches needs a pad token; fall back to EOS.
model_id = config["model"]
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Model, optionally quantized with bitsandbytes.
quantization = config.get("quantization", "none")
if quantization == "4bit":
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
    )
elif quantization == "8bit":
    quant_config = BitsAndBytesConfig(load_in_8bit=True)
else:
    quant_config = None

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    # Needed for models with custom code (e.g. ChatGLM); this executes code
    # from the model repository.
    trust_remote_code=True,
    quantization_config=quant_config,
)
# Only meant for k-bit models; on a plain fp16 model it upcasts weights to fp32.
if quant_config is not None:
    model = prepare_model_for_kbit_training(model)

# LoRA
lora_config = LoraConfig(
    r=config["lora_r"],
    lora_alpha=config["lora_alpha"],
    target_modules=config.get("target_modules", ["query_key_value"]),
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Dataset: JSONL, one object per line; blank lines are skipped.
with open(config["dataset"], encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]
if not data:
    raise SystemExit(f"[GLM] Dataset vide: {config['dataset']}")
dataset = Dataset.from_list(data)


def tokenize(example):
    # Absent or null fields count as empty text.
    text = (example.get("prompt") or "") + (example.get("completion") or "")
    return tokenizer(text, truncation=True, max_length=config["max_seq_length"])


# Non-batched on purpose: tokenize() handles a single example.
tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)

# Train
training_args = TrainingArguments(
    output_dir=config["output"],
    num_train_epochs=config["epochs"],
    per_device_train_batch_size=config["batch_size"],
    learning_rate=config["learning_rate"],
    logging_steps=10,
    save_strategy="epoch",
    fp16=torch.cuda.is_available(),  # fp16 mixed precision requires a GPU
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    # Pads each batch and copies input_ids into labels (padding -> -100),
    # so the model can compute a causal-LM loss.
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
trainer.train()
model.save_pretrained(config["output"])
tokenizer.save_pretrained(config["output"])
print(f"[GLM] Modèle sauvé dans {config['output']}")
'''
    # repr() of the JSON text is always a valid Python string literal, whatever
    # quotes or backslashes the config values contain. The previous raw
    # triple-quoted wrapper broke on values containing three single quotes.
    return template.replace("__CONFIG_JSON__", repr(json.dumps(config)))
|
|
|
|
def main():
    """Command-line entry point for glm-trainer.

    Commands:
        check  -- try importing the training dependencies; exits with status 1
                  if one is missing.
        config -- print DEFAULT_CONFIG as indented JSON.
        train  -- check dependencies, merge the CLI options into a copy of
                  DEFAULT_CONFIG and write ``<output>/train_script.py``.
                  The script is only generated; the user runs it.
    """
    parser = argparse.ArgumentParser(prog="glm-trainer")
    parser.add_argument("command", choices=["train", "config", "check"])
    # Defaults come from DEFAULT_CONFIG so the two cannot drift apart.
    parser.add_argument("--dataset", default=DEFAULT_CONFIG["dataset"])
    parser.add_argument("--output", default=DEFAULT_CONFIG["output"])
    parser.add_argument("--model", default=DEFAULT_CONFIG["model"])
    parser.add_argument("--epochs", type=int, default=DEFAULT_CONFIG["epochs"])
    args = parser.parse_args()

    if args.command == "check":
        # Report failure through the exit status so scripts/CI can rely on it.
        if not check_deps():
            sys.exit(1)
        return

    if args.command == "config":
        print(json.dumps(DEFAULT_CONFIG, indent=2))
        return

    if args.command == "train":
        if not check_deps():
            sys.exit(1)

        # The user is told to `cd` into the output directory before running
        # the script, so local paths are made absolute now. Left relative,
        # they would be resolved against the output directory instead of the
        # directory this command was run from.
        output_dir = Path(args.output).expanduser().resolve()
        dataset_path = Path(args.dataset).expanduser().resolve()
        model = args.model
        model_path = Path(model).expanduser()
        if model_path.exists():  # a local checkpoint, not a Hub id
            model = str(model_path.resolve())

        config = {
            **DEFAULT_CONFIG,
            "dataset": str(dataset_path),
            "output": str(output_dir),
            "model": model,
            "epochs": args.epochs,
        }
        script = generate_training_script(config)
        script_path = output_dir / "train_script.py"
        script_path.parent.mkdir(parents=True, exist_ok=True)
        # The script may contain non-ASCII text; write UTF-8 explicitly so
        # Python can parse it regardless of the platform's default encoding.
        script_path.write_text(script, encoding="utf-8")
        print(f"[GLM] Script généré: {script_path}")
        print(f"[GLM] Lancer: cd {output_dir} && python train_script.py")


if __name__ == "__main__":
    main()
|