# Finetuning

This is an example of fine-tuning Gemma. For an example of how to run a pre-trained Gemma model, see the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial.

To fine-tune Gemma, we use the [kauldron](https://kauldron.readthedocs.io/en/latest/) library, which abstracts away most of the boilerplate (checkpoint management, training loop, evaluation, metric reporting, sharding, …).

```
!pip install -q gemma
```

```
# Common imports
import os
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm
```

By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

```
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
```

## Data pipeline

First create the tokenizer, as it's required in the data pipeline.

```
tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)
```

```
[<_Gemma2SpecialTokens.BOS: 2>, 1596, 603, 671, 3287, 13060]
```

Next, we need a data pipeline. Multiple pipelines are supported, including:

- [HuggingFace](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/HuggingFace.html)
- [TFDS](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/Tfds.html)
- [Json](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/Json.html)
- …

It's quite simple to add your own data, or to create mixtures from multiple sources. See the [pipeline documentation](https://kauldron.readthedocs.io/en/latest/data_py.html).

We use `transforms` to customize the data pipeline. These include:

- Tokenizing the inputs (with [`gm.data.Tokenize`](https://gemma-llm.readthedocs.io/en/latest/api/gm/data/Tokenize.html))
- Creating the model inputs, targets and loss mask (with `gm.data.Seq2SeqTask`, used below)
- Adding padding (with [`gm.data.Pad`](https://gemma-llm.readthedocs.io/en/latest/api/gm/data/Pad.html)), required to batch inputs with different lengths

Note that in practice, you can combine multiple transforms into a higher-level transform. See the `gm.data.ContrastiveTask()` transform in the [DPO example](https://github.com/google-deepmind/gemma/tree/main/examples/dpo.py) for an example.

Here, we use [mtnt](https://www.tensorflow.org/datasets/catalog/mtnt), a small translation dataset. The dataset structure is `{'src': ..., 'dst': ...}`.

```
ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)
```

```
Disabling pygrain multi-processing (unsupported in colab).

{
    'input': i64[8 200],
    'loss_mask': bool_[8 200 1],
    'target': i64[8 200 1],
}
```
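To sanity-check what `gm.data.Seq2SeqTask` produced, we can decode only the tokens selected by `loss_mask`. The sketch below is illustrative: it assumes the batch holds NumPy-compatible arrays and that `tokenizer.decode` accepts a plain list of token ids. If the mask is set up as expected, only the model response (not the prompt or padding) is printed.

```
# Sanity-check the loss mask (a sketch; assumes NumPy-compatible arrays
# and that `tokenizer.decode` accepts a list of token ids).
import numpy as np

target = np.asarray(ex['target'][0]).reshape(-1)        # i64[200]
loss_mask = np.asarray(ex['loss_mask'][0]).reshape(-1)  # bool_[200]

# Only the response tokens should contribute to the loss; prompt and
# padding tokens should be masked out.
response_tokens = target[loss_mask].tolist()
print(tokenizer.decode(response_tokens))
```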
We can also decode an example from the batch to inspect the model input. We see that the `<start_of_turn>` / `<end_of_turn>` tokens were correctly added to follow the Gemma dialog format.

```
text = tokenizer.decode(ex['input'][0])

print(text)
```

```
<start_of_turn>user
Would love any other tips from anyone, but specially from someone who’s been where I’m at.<end_of_turn>
<start_of_turn>model
J'apprécierais vraiment d'autres astuces, mais particulièrement par quelqu'un qui était était déjà là où je me trouve.
```

## Trainer

The [kauldron](https://kauldron.readthedocs.io/en/latest/) trainer allows training Gemma simply by providing a dataset, a model, a loss and an optimizer.

Dataset, model and losses are connected together through a system of `key` strings. For more information, see the [key documentation](https://kauldron.readthedocs.io/en/latest/intro.html#keys-and-context).

Each key starts with a registered prefix. Common prefixes include:

- `batch`: The output of the dataset (after all transformations). Here our batch is `{'input': ..., 'loss_mask': ..., 'target': ...}`
- `preds`: The output of the model. For Gemma models, this is `gm.nn.Output(logits=..., cache=...)`
- `params`: Model parameters (can be used to add a weight-decay loss, or to monitor the parameter norms in metrics)

```
model = gm.nn.Gemma3_4B(
    tokens="batch.input",
)
```

```
loss = kd.losses.SoftmaxCrossEntropyWithIntLabels(
    logits="preds.logits",
    labels="batch.target",
    mask="batch.loss_mask",
)
```

We then create the trainer:

```
trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default

    # Dataset
    train_ds=ds,

    # Model
    model=model,
    init_transform=gm.ckpts.LoadCheckpoint(
        # Load the weights from the pretrained checkpoint
        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
    ),

    # Training parameters
    num_train_steps=300,
    train_losses={"loss": loss},
    optimizer=optax.adafactor(learning_rate=1e-3),
)
```

Training can be launched with the `.train()` method. Note that the trainer, like the model, is immutable, so it stores neither the state nor the params. Instead, the state containing the trained parameters is returned.

```
state, aux = trainer.train()
```

```
Configuring ...
Initializing ...
Starting training loop at step 0
```

## Checkpointing

To save the model params, you can either:

- Activate checkpointing in the trainer by adding:

  ```
  trainer = kd.train.Trainer(
      workdir='/tmp/my_experiment/',
      checkpointer=kd.ckpts.Checkpointer(
          save_interval_steps=500,
      ),
      ...
  )
  ```

  This will also save the optimizer, step, dataset state, …

- Manually save the trained params:

  ```
  gm.ckpts.save_params(state.params, '/tmp/my_ckpt/')
  ```

## Evaluation

Here, we only perform a qualitative evaluation by sampling a prompt. For more info on evals:

- See the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial for more info on running inference.
- To add evals during training, see the Kauldron [evaluator](https://kauldron.readthedocs.io/en/latest/eval.html) documentation.

```
sampler = gm.text.ChatSampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)
```

We test a sentence, using the same formatting used during fine-tuning:

```
sampler.chat('Hello! My next holidays are in Paris.')
```

```
'Salut ! Mes vacances suivantes seront à Paris.'
```

The model correctly translated our prompt to French!
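If you saved the params manually, they can also be restored in a fresh process and used for sampling. This is a minimal sketch, assuming `gm.ckpts.load_params` is the counterpart of the `gm.ckpts.save_params` call above; adapt the path to your setup.

```
# Restore fine-tuned params and sample from them (a sketch; assumes
# `gm.ckpts.load_params` mirrors `gm.ckpts.save_params`).
from gemma import gm

tokenizer = gm.text.Gemma3Tokenizer()
model = gm.nn.Gemma3_4B(tokens="batch.input")

# Path used with `gm.ckpts.save_params` in the checkpointing section.
params = gm.ckpts.load_params('/tmp/my_ckpt/')

sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)
print(sampler.chat('Hello! My next holidays are in Paris.'))
```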
