From f841d591e62208bba1c0a895889c4f8f37f001b1 Mon Sep 17 00:00:00 2001 From: noopy Date: Wed, 1 Jan 2025 21:12:14 +0900 Subject: [PATCH 1/2] fix pretrained parameter assignment in text classification with flax --- examples/text_classification_flax.ipynb | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/text_classification_flax.ipynb b/examples/text_classification_flax.ipynb index 24dbf775..059a3969 100644 --- a/examples/text_classification_flax.ipynb +++ b/examples/text_classification_flax.ipynb @@ -1481,6 +1481,9 @@ }, "outputs": [], "source": [ + "unreplicated_params = flax.jax_utils.unreplicate(state.params)\n", + "model.params = jax.device_get(jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), unreplicated_params))\n", + "\n", "model.push_to_hub(model_id, use_auth_token=hf_auth_token)\n", "tokenizer.push_to_hub(model_id, use_auth_token=hf_auth_token)" ] From 190ff5b6c96429768841dba686bc022177bba318 Mon Sep 17 00:00:00 2001 From: noopy Date: Wed, 1 Jan 2025 21:45:42 +0900 Subject: [PATCH 2/2] add pretrained LM parameter uploading code to huggingface hub --- examples/causal_language_modeling_flax.ipynb | 165 ++++++++++++------- 1 file changed, 105 insertions(+), 60 deletions(-) diff --git a/examples/causal_language_modeling_flax.ipynb b/examples/causal_language_modeling_flax.ipynb index 4dfe0484..2935046d 100644 --- a/examples/causal_language_modeling_flax.ipynb +++ b/examples/causal_language_modeling_flax.ipynb @@ -1550,9 +1550,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (1/10 | Loss: 6.935000419616699, Learning Rate: 0.0002699999895412475)\n" ] }, @@ -1576,9 +1576,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... 
(1/10 | Loss: 7.108445644378662 | Perplexity: 1246.529052734375)\n" ] }, @@ -1602,9 +1602,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (2/10 | Loss: 6.334000110626221, Learning Rate: 0.00023999999393709004)\n" ] }, @@ -1628,9 +1628,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... (2/10 | Loss: 6.567610740661621 | Perplexity: 738.8753662109375)\n" ] }, @@ -1654,9 +1654,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (3/10 | Loss: 5.798000335693359, Learning Rate: 0.0002099999983329326)\n" ] }, @@ -1680,9 +1680,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... (3/10 | Loss: 6.278167247772217 | Perplexity: 557.9488525390625)\n" ] }, @@ -1706,9 +1706,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (4/10 | Loss: 5.557000160217285, Learning Rate: 0.00018000000272877514)\n" ] }, @@ -1732,9 +1732,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... (4/10 | Loss: 6.062875270843506 | Perplexity: 451.3289794921875)\n" ] }, @@ -1758,9 +1758,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (5/10 | Loss: 5.543000221252441, Learning Rate: 0.00014999999257270247)\n" ] }, @@ -1784,9 +1784,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... (5/10 | Loss: 5.920379161834717 | Perplexity: 392.97332763671875)\n" ] }, @@ -1810,9 +1810,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... 
(6/10 | Loss: 5.361000061035156, Learning Rate: 0.00011999999696854502)\n" ] }, @@ -1836,9 +1836,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... (6/10 | Loss: 5.821027755737305 | Perplexity: 356.4353942871094)\n" ] }, @@ -1862,9 +1862,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (7/10 | Loss: 5.207000255584717, Learning Rate: 9.000000136438757e-05)\n" ] }, @@ -1888,9 +1888,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... (7/10 | Loss: 5.748736381530762 | Perplexity: 332.1453857421875)\n" ] }, @@ -1914,9 +1914,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (8/10 | Loss: 5.124000072479248, Learning Rate: 5.999999848427251e-05)\n" ] }, @@ -1940,9 +1940,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... (8/10 | Loss: 5.703180313110352 | Perplexity: 317.5106201171875)\n" ] }, @@ -1966,9 +1966,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (9/10 | Loss: 5.220000267028809, Learning Rate: 2.9999999242136255e-05)\n" ] }, @@ -1992,9 +1992,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... (9/10 | Loss: 5.674434185028076 | Perplexity: 308.7478942871094)\n" ] }, @@ -2018,9 +2018,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Train... (10/10 | Loss: 4.992000102996826, Learning Rate: 0.0)\n" ] }, @@ -2044,9 +2044,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", - "\r", - "\r", + "\r\n", + "\r\n", + "\r\n", "Eval... 
(10/10 | Loss: 5.66389274597168 | Perplexity: 305.58953857421875)\n", "\n" ] @@ -2098,6 +2098,51 @@ "\n", "For a more in-detail comparison of runtimes please refer to [this](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation) table." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You may upload the model to the Hugging Face Hub. The model will be uploaded under `https://huggingface.co/<your_username>/<model_dir>`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "unreplicated_params = flax.jax_utils.unreplicate(state.params)\n", + "model.params = jax.device_get(jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), unreplicated_params))\n", + "\n", + "model.push_to_hub(model_dir)\n", + "tokenizer.push_to_hub(model_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After uploading, you may use the model to generate text with a `pipeline`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use a pipeline as a high-level helper\n", + "from transformers import pipeline\n", + "\n", + "your_username = \"\" # your Hugging Face user name\n", + "pipe = pipeline(\"text-generation\", model=f\"{your_username}/{model_dir}\")\n", + "\n", + "sample_prompt = \"Týðingin verður løgd til almennar\" # sample input\n", + "\n", + "# Decode with beam search using 5 beams\n", + "pipe(sample_prompt, max_new_tokens=20, num_beams=5)" + ] + } ], "metadata": {