Merge remote-tracking branch 'upstream/main'

Some checks failed:
* ci/woodpecker/push/build: Pipeline failed
* ci/woodpecker/manual/build: Pipeline failed
* ci/woodpecker/tag/build: Pipeline is running

Author: ryan, 2023-09-22 15:20:23 -04:00
Commit: 77eac39bc8
223 changed files with 10195 additions and 3963 deletions

.github/pull_request_template.md (vendored, new file, 3 lines)

@ -0,0 +1,3 @@
+## Checklist:
+- [ ] I have read the [Contributing guidelines](https://github.com/oobabooga/text-generation-webui/wiki/Contributing-guidelines).


@ -13,8 +13,8 @@ jobs:
      - uses: actions/stale@v5
        with:
          stale-issue-message: ""
-          close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, please leave a comment below."
+          close-issue-message: "This issue has been closed due to inactivity for 6 weeks. If you believe it is still relevant, please leave a comment below. You can tag a developer in your comment."
-          days-before-issue-stale: 30
+          days-before-issue-stale: 42
          days-before-issue-close: 0
          stale-issue-label: "stale"
          days-before-pr-stale: -1

README.md (186 changed lines)

@ -4,30 +4,28 @@
# Text generation web UI
-A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA.
+A Gradio web UI for Large Language Models.
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
-|![Image1](https://github.com/oobabooga/screenshots/raw/main/qa.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/cai3.png) |
+|![Image1](https://github.com/oobabooga/screenshots/raw/main/print_instruct.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/print_chat.png) |
|:---:|:---:|
-|![Image3](https://github.com/oobabooga/screenshots/raw/main/gpt4chan.png) | ![Image4](https://github.com/oobabooga/screenshots/raw/main/galactica.png) |
+|![Image1](https://github.com/oobabooga/screenshots/raw/main/print_default.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/print_parameters.png) |
## Features
-* 3 interface modes: default, notebook, and chat
+* 3 interface modes: default (two columns), notebook, and chat
-* Multiple model backends: tranformers, llama.cpp, AutoGPTQ, GPTQ-for-LLaMa, ExLlama, RWKV, FlexGen
+* Multiple model backends: [transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [ExLlama](https://github.com/turboderp/exllama), [ExLlamaV2](https://github.com/turboderp/exllamav2), [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa), [CTransformers](https://github.com/marella/ctransformers)
* Dropdown menu for quickly switching between different models
-* LoRA: load and unload LoRAs on the fly, load multiple LoRAs at the same time, train a new LoRA
+* LoRA: load and unload LoRAs on the fly, train a new LoRA using QLoRA
-* Precise instruction templates for chat mode, including Alpaca, Vicuna, Open Assistant, Dolly, Koala, ChatGLM, MOSS, RWKV-Raven, Galactica, StableLM, WizardLM, Baize, Ziya, Chinese-Vicuna, MPT, INCITE, Wizard Mega, KoAlpaca, Vigogne, Bactrian, h2o, and OpenBuddy
+* Precise instruction templates for chat mode, including Llama-2-chat, Alpaca, Vicuna, WizardLM, StableLM, and many others
+* 4-bit, 8-bit, and CPU inference through the transformers library
+* Use llama.cpp models with transformers samplers (`llamacpp_HF` loader)
* [Multimodal pipelines, including LLaVA and MiniGPT-4](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal)
-* 8-bit and 4-bit inference through bitsandbytes
-* CPU mode for transformers models
-* [DeepSpeed ZeRO-3 inference](docs/DeepSpeed.md)
-* [Extensions](docs/Extensions.md)
+* [Extensions framework](docs/Extensions.md)
* [Custom chat characters](docs/Chat-mode.md)
* Very efficient text streaming
* Markdown output with LaTeX rendering, to use for instance with [GALACTICA](https://github.com/paperswithcode/galai)
-* Nice HTML output for GPT-4chan
* API, including endpoints for websocket streaming ([see the examples](https://github.com/oobabooga/text-generation-webui/blob/main/api-examples))
To learn how to use the various features, check out the Documentation: https://github.com/oobabooga/text-generation-webui/tree/main/docs
@ -42,26 +40,24 @@ To learn how to use the various features, check out the Documentation: https://g
Just download the zip above, extract it, and double-click on "start". The web UI and all its dependencies will be installed in the same folder.
-* The source codes are here: https://github.com/oobabooga/one-click-installers
+* The source codes and more information can be found here: https://github.com/oobabooga/one-click-installers
* There is no need to run the installers as admin.
-* AMD doesn't work on Windows.
* Huge thanks to [@jllllll](https://github.com/jllllll), [@ClayShoaf](https://github.com/ClayShoaf), and [@xNul](https://github.com/xNul) for their contributions to these installers.
### Manual installation using Conda
-Recommended if you have some experience with the command line.
+Recommended if you have some experience with the command-line.
#### 0. Install Conda
https://docs.conda.io/en/latest/miniconda.html
-On Linux or WSL, it can be automatically installed with these two commands:
+On Linux or WSL, it can be automatically installed with these two commands ([source](https://educe-ubc.github.io/conda.html)):
```
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
bash Miniconda3.sh
```
-Source: https://educe-ubc.github.io/conda.html
#### 1. Create a new conda environment
@ -75,17 +71,14 @@ conda activate textgen
| System | GPU | Command |
|--------|---------|---------|
| Linux/WSL | NVIDIA | `pip3 install torch torchvision torchaudio` |
+| Linux/WSL | CPU only | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu` |
| Linux | AMD | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2` |
-| MacOS + MPS (untested) | Any | `pip3 install torch torchvision torchaudio` |
+| MacOS + MPS | Any | `pip3 install torch torchvision torchaudio` |
| Windows | NVIDIA | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117` |
-| Windows | CPU only | `pip3 install torch torchvision torchaudio` |
The up-to-date commands can be found here: https://pytorch.org/get-started/locally/.
-#### 2.1 Special instructions
-* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393
-* AMD users: https://rentry.org/eq3hg
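As an optional sanity check after running one of the install commands in the table above, you can confirm which accelerator the installed PyTorch build actually sees. This is a hypothetical aside using only standard PyTorch calls; it is not part of the README being changed here.

```python
import torch

# Optional check after installing PyTorch with one of the commands above.
print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())        # NVIDIA builds
print("ROCm/HIP build:", torch.version.hip is not None)    # AMD builds report a HIP version
# Apple Metal (MPS) is reported separately; the getattr guard covers older torch versions.
print("MPS available:", getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available())
```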
#### 3. Install the web UI
```
@ -94,13 +87,30 @@ cd text-generation-webui
pip install -r requirements.txt
```
-#### llama.cpp with GPU acceleration
-Requires the additional compilation step described here: [GPU acceleration](https://github.com/oobabooga/text-generation-webui/blob/main/docs/llama.cpp-models.md#gpu-acceleration).
-#### bitsandbytes
-bitsandbytes >= 0.39 may not work on older NVIDIA GPUs. In that case, to use `--load-in-8bit`, you may have to downgrade like this:
+#### AMD, Metal, Intel Arc, and CPUs without AVX2
+1) Replace the last command above with
+```
+pip install -r requirements_nocuda.txt
+```
+2) Manually install llama-cpp-python using the appropriate command for your hardware: [Installation from PyPI](https://github.com/abetlen/llama-cpp-python#installation-from-pypi).
+3) Do the same for CTransformers: [Installation](https://github.com/marella/ctransformers#installation).
+4) AMD: Manually install AutoGPTQ: [Installation](https://github.com/PanQiWei/AutoGPTQ#installation).
+5) AMD: Manually install [ExLlama](https://github.com/turboderp/exllama) by simply cloning it into the `repositories` folder (it will be automatically compiled at runtime after that):
+```
+cd text-generation-webui
+git clone https://github.com/turboderp/exllama repositories/exllama
+```
+#### bitsandbytes on older NVIDIA GPUs
+bitsandbytes >= 0.39 may not work. In that case, to use `--load-in-8bit`, you may have to downgrade like this:
* Linux: `pip install bitsandbytes==0.38.1`
* Windows: `pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl`
@ -119,37 +129,48 @@ docker compose up --build
### Updating the requirements
-From time to time, the `requirements.txt` changes. To update, use this command:
+From time to time, the `requirements.txt` changes. To update, use these commands:
```
conda activate textgen
cd text-generation-webui
pip install -r requirements.txt --upgrade
```
## Downloading models
-Models should be placed inside the `models/` folder.
-[Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) is the main place to download models. These are some examples:
-* [Pythia](https://huggingface.co/models?sort=downloads&search=eleutherai%2Fpythia+deduped)
-* [OPT](https://huggingface.co/models?search=facebook/opt)
-* [GALACTICA](https://huggingface.co/models?search=facebook/galactica)
-* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main)
-You can automatically download a model from HF using the script `download-model.py`:
-python download-model.py organization/model
-For example:
-python download-model.py facebook/opt-1.3b
-To download a protected model, set env vars `HF_USER` and `HF_PASS` to your Hugging Face username and password (or [User Access Token](https://huggingface.co/settings/tokens)). The model's terms must first be accepted on the HF website.
-#### GGML models
-You can drop these directly into the `models/` folder, making sure that the file name contains `ggml` somewhere and ends in `.bin`.
+Models should be placed in the `text-generation-webui/models` folder. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads).
+* Transformers or GPTQ models are made of several files and must be placed in a subfolder. Example:
+```
+text-generation-webui
+├── models
+│   ├── lmsys_vicuna-33b-v1.3
+│   │   ├── config.json
+│   │   ├── generation_config.json
+│   │   ├── pytorch_model-00001-of-00007.bin
+│   │   ├── pytorch_model-00002-of-00007.bin
+│   │   ├── pytorch_model-00003-of-00007.bin
+│   │   ├── pytorch_model-00004-of-00007.bin
+│   │   ├── pytorch_model-00005-of-00007.bin
+│   │   ├── pytorch_model-00006-of-00007.bin
+│   │   ├── pytorch_model-00007-of-00007.bin
+│   │   ├── pytorch_model.bin.index.json
+│   │   ├── special_tokens_map.json
+│   │   ├── tokenizer_config.json
+│   │   └── tokenizer.model
+```
+* GGUF models are a single file and should be placed directly into `models`. Example:
+```
+text-generation-webui
+├── models
+│   ├── llama-2-13b-chat.Q4_K_M.gguf
+```
+In both cases, you can use the "Model" tab of the UI to download the model from Hugging Face automatically. It is also possible to download via the command-line with `python download-model.py organization/model` (use `--help` to see all the options).
#### GPT-4chan
@ -175,7 +196,10 @@ After downloading the model, follow these steps:
python download-model.py EleutherAI/gpt-j-6B --text-only
```
-When you load this model in default or notebook modes, the "HTML" tab will show the generated text in 4chan format.
+When you load this model in default or notebook modes, the "HTML" tab will show the generated text in 4chan format:
+![Image3](https://github.com/oobabooga/screenshots/raw/main/gpt4chan.png)
</details>
## Starting the web UI
@ -195,8 +219,6 @@ Optionally, you can use the following command-line flags:
| Flag | Description |
|--------------------------------------------|-------------|
| `-h`, `--help` | Show this help message and exit. |
-| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
-| `--chat` | Launch the web UI in chat mode. |
| `--multi-user` | Multi-user mode. Chat histories are not saved or automatically loaded. WARNING: this is highly experimental. |
| `--character CHARACTER` | The name of the character to load in chat mode by default. |
| `--model MODEL` | Name of the model to load by default. |
@ -204,16 +226,16 @@ Optionally, you can use the following command-line flags:
| `--model-dir MODEL_DIR` | Path to directory with all the models. |
| `--lora-dir LORA_DIR` | Path to directory with all the loras. |
| `--model-menu` | Show a model menu in the terminal when the web UI is first launched. |
-| `--no-stream` | Don't stream the text output in real time. |
| `--settings SETTINGS_FILE` | Load the default interface settings from this yaml file. See `settings-template.yaml` for an example. If you create a file called `settings.yaml`, this file will be loaded by default without the need to use the `--settings` flag. |
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
| `--verbose` | Print the prompts to the terminal. |
+| `--chat-buttons` | Show buttons on chat tab instead of hover menu. |
#### Model loader
| Flag | Description |
|--------------------------------------------|-------------|
-| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen |
+| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, ctransformers |
#### Accelerate/transformers
@ -243,18 +265,33 @@ Optionally, you can use the following command-line flags:
| `--quant_type QUANT_TYPE` | quant_type for 4-bit. Valid options: nf4, fp4. |
| `--use_double_quant` | use_double_quant for 4-bit. |
-#### llama.cpp
+#### GGUF (for llama.cpp and ctransformers)
| Flag | Description |
|-------------|-------------|
| `--threads` | Number of threads to use. |
| `--n_batch` | Maximum number of prompt tokens to batch together when calling llama_eval. |
-| `--no-mmap` | Prevent mmap from being used. |
-| `--mlock` | Force the system to keep the model in RAM. |
-| `--cache-capacity CACHE_CAPACITY` | Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed. |
| `--n-gpu-layers N_GPU_LAYERS` | Number of layers to offload to the GPU. Only works if llama-cpp-python was compiled with BLAS. Set this to 1000000000 to offload all layers to the GPU. |
| `--n_ctx N_CTX` | Size of the prompt context. |
+#### llama.cpp
+| Flag | Description |
+|---------------|---------------|
+| `--no-mmap` | Prevent mmap from being used. |
+| `--mlock` | Force the system to keep the model in RAM. |
+| `--mul_mat_q` | Activate new mulmat kernels. |
+| `--cache-capacity CACHE_CAPACITY` | Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed. |
+| `--tensor_split TENSOR_SPLIT` | Split the model across multiple GPUs, comma-separated list of proportions, e.g. 18,17 |
| `--llama_cpp_seed SEED` | Seed for llama-cpp models. Default 0 (random). |
+| `--cpu` | Use the CPU version of llama-cpp-python instead of the GPU-accelerated version. |
+|`--cfg-cache` | llamacpp_HF: Create an additional cache for CFG negative prompts. |
+#### ctransformers
+| Flag | Description |
+|-------------|-------------|
+| `--model_type MODEL_TYPE` | Model type of pre-quantized model. Currently gpt2, gptj, gptneox, falcon, llama, mpt, starcoder (gptbigcode), dollyv2, and replit are supported. |
#### AutoGPTQ
@ -265,6 +302,7 @@ Optionally, you can use the following command-line flags:
| `--no_inject_fused_mlp` | Triton mode only: disable the use of fused MLP, which will use less VRAM at the cost of slower inference. |
| `--no_use_cuda_fp16` | This can make models faster on some systems. |
| `--desc_act` | For models that don't have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig. |
+| `--disable_exllama` | Disable ExLlama kernel, which can improve inference speed on some systems. |
#### ExLlama
@ -272,8 +310,7 @@ Optionally, you can use the following command-line flags:
|------------------|-------------|
|`--gpu-split` | Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. `20,7,7` |
|`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. |
-|`--compress_pos_emb COMPRESS_POS_EMB` | Positional embeddings compression factor. Should typically be set to max_seq_len / 2048. |
-|`--alpha_value ALPHA_VALUE` | Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both. |
+|`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
#### GPTQ-for-LLaMa
@ -285,17 +322,6 @@ Optionally, you can use the following command-line flags:
| `--pre_layer PRE_LAYER [PRE_LAYER ...]` | The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. For multi-gpu, write the numbers separated by spaces, eg `--pre_layer 30 60`. |
| `--checkpoint CHECKPOINT` | The path to the quantized checkpoint file. If not specified, it will be automatically detected. |
| `--monkey-patch` | Apply the monkey patch for using LoRAs with quantized models. |
-| `--quant_attn` | (triton) Enable quant attention. |
-| `--warmup_autotune` | (triton) Enable warmup autotune. |
-| `--fused_mlp` | (triton) Enable fused mlp. |
-#### FlexGen
-| Flag | Description |
-|------------------|-------------|
-| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
-| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
-| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
#### DeepSpeed
@ -312,6 +338,14 @@ Optionally, you can use the following command-line flags:
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
+#### RoPE (for llama.cpp, ExLlama, ExLlamaV2, and transformers)
+| Flag | Description |
+|------------------|-------------|
+| `--alpha_value ALPHA_VALUE` | Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both. |
+| `--rope_freq_base ROPE_FREQ_BASE` | If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63). |
+| `--compress_pos_emb COMPRESS_POS_EMB` | Positional embeddings compression factor. Should be set to (context length) / (model's original context length). Equal to 1/rope_freq_scale. |
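The relationship between the new RoPE flags in the table above can be checked with a few lines of Python. The helper below is purely illustrative; the function name and the 4096-token default context are assumptions, not values from the project.

```python
# Illustrative only: relates the RoPE flags described in the table above.
def rope_settings(alpha_value: float, target_ctx: int, original_ctx: int = 4096) -> dict:
    # rope_freq_base is derived from alpha_value: rope_freq_base = 10000 * alpha_value ^ (64 / 63)
    rope_freq_base = 10000 * alpha_value ** (64 / 63)
    # compress_pos_emb = (desired context) / (model's original context); use it or alpha_value, not both
    compress_pos_emb = target_ctx / original_ctx
    return {"rope_freq_base": rope_freq_base, "compress_pos_emb": compress_pos_emb}

print(rope_settings(alpha_value=2.5, target_ctx=8192))
# rope_freq_base comes out to roughly 25365, compress_pos_emb to 2.0
```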
#### Gradio
| Flag | Description |
@ -323,6 +357,8 @@ Optionally, you can use the following command-line flags:
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--gradio-auth USER:PWD` | set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3" |
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
+| `--ssl-keyfile SSL_KEYFILE` | The path to the SSL certificate key file. |
+| `--ssl-certfile SSL_CERTFILE` | The path to the SSL certificate cert file. |
#### API
@ -330,6 +366,7 @@ Optionally, you can use the following command-line flags:
|---------------------------------------|-------------|
| `--api` | Enable the API extension. |
| `--public-api` | Create a public URL for the API using Cloudfare. |
+| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. |
| `--api-blocking-port BLOCKING_PORT` | The listening port for the blocking API. |
| `--api-streaming-port STREAMING_PORT` | The listening port for the streaming API. |
@ -339,8 +376,6 @@ Optionally, you can use the following command-line flags:
|---------------------------------------|-------------|
| `--multimodal-pipeline PIPELINE` | The multimodal pipeline to use. Examples: `llava-7b`, `llava-13b`. |
-Out of memory errors? [Check the low VRAM guide](docs/Low-VRAM-guide.md).
## Presets
Inference settings presets can be created under `presets/` as yaml files. These files are detected automatically at startup.
@ -349,18 +384,13 @@ The presets that are included by default are the result of a contest that receiv
## Contributing
-* Pull requests, suggestions, and issue reports are welcome.
-* Make sure to carefully [search](https://github.com/oobabooga/text-generation-webui/issues) existing issues before starting a new one.
-* If you have some experience with git, testing an open pull request and leaving a comment on whether it works as expected or not is immensely helpful.
-* A simple way to contribute, even if you are not a programmer, is to leave a 👍 on an issue or pull request that you find relevant.
+If you would like to contribute to the project, check out the [Contributing guidelines](https://github.com/oobabooga/text-generation-webui/wiki/Contributing-guidelines).
## Community
-* Subreddit: https://www.reddit.com/r/oobaboogazz/
+* Subreddit: https://www.reddit.com/r/oobabooga/
* Discord: https://discord.gg/jwZCF2dPQN
-## Credits
-- Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui
-- Godlike preset: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
-- Code for some of the sliders: https://github.com/PygmalionAI/gradio-ui/
+## Acknowledgment
+In August 2023, [Andreessen Horowitz](https://a16z.com/) (a16z) provided a generous grant to encourage and support my independent work on this project. I am **extremely** grateful for their trust and recognition, which will allow me to dedicate more time towards realizing the full potential of text-generation-webui.


@ -1,4 +1,5 @@
import asyncio
+import html
import json
import sys
@ -20,17 +21,24 @@ async def run(user_input, history):
    request = {
        'user_input': user_input,
        'max_new_tokens': 250,
+        'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
        'history': history,
        'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
        'character': 'Example',
-        'instruction_template': 'Vicuna-v1.1',
+        'instruction_template': 'Vicuna-v1.1',  # Will get autodetected if unset
        'your_name': 'You',
+        # 'name1': 'name of user', # Optional
+        # 'name2': 'name of character', # Optional
+        # 'context': 'character context', # Optional
+        # 'greeting': 'greeting', # Optional
+        # 'name1_instruct': 'You', # Optional
+        # 'name2_instruct': 'Assistant', # Optional
+        # 'context_instruct': 'context_instruct', # Optional
+        # 'turn_template': 'turn_template', # Optional
        'regenerate': False,
        '_continue': False,
-        'stop_at_newline': False,
-        'chat_generation_attempts': 1,
-        'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
+        'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
        # Generation params. If 'preset' is set to different than 'None', the values
        # in presets/preset-name.yaml are used instead of the individual numbers.
@ -55,11 +63,14 @@ async def run(user_input, history):
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
+        'custom_token_bans': '',
        'skip_special_tokens': True,
        'stopping_strings': []
    }
@ -83,7 +94,7 @@ async def print_response_stream(user_input, history):
    async for new_history in run(user_input, history):
        cur_message = new_history['visible'][-1][1][cur_len:]
        cur_len += len(cur_message)
-        print(cur_message, end='')
+        print(html.unescape(cur_message), end='')
        sys.stdout.flush()  # If we don't flush, we won't see tokens in realtime.
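The change above wraps the streamed text in `html.unescape`, a standard-library call that turns HTML entities back into literal characters. A minimal illustration:

```python
import html

# The API returns HTML-escaped text; html.unescape restores the literal characters.
escaped = "if a &lt; b &amp;&amp; b &gt; c: print(&quot;ok&quot;)"
print(html.unescape(escaped))
# if a < b && b > c: print("ok")
```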


@ -1,3 +1,4 @@
+import html
import json
import requests
@ -14,17 +15,24 @@ def run(user_input, history):
    request = {
        'user_input': user_input,
        'max_new_tokens': 250,
+        'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
        'history': history,
        'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
        'character': 'Example',
-        'instruction_template': 'Vicuna-v1.1',
+        'instruction_template': 'Vicuna-v1.1',  # Will get autodetected if unset
        'your_name': 'You',
+        # 'name1': 'name of user', # Optional
+        # 'name2': 'name of character', # Optional
+        # 'context': 'character context', # Optional
+        # 'greeting': 'greeting', # Optional
+        # 'name1_instruct': 'You', # Optional
+        # 'name2_instruct': 'Assistant', # Optional
+        # 'context_instruct': 'context_instruct', # Optional
+        # 'turn_template': 'turn_template', # Optional
        'regenerate': False,
        '_continue': False,
-        'stop_at_newline': False,
-        'chat_generation_attempts': 1,
-        'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
+        'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
        # Generation params. If 'preset' is set to different than 'None', the values
        # in presets/preset-name.yaml are used instead of the individual numbers.
@ -49,11 +57,14 @@ def run(user_input, history):
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
+        'custom_token_bans': '',
        'skip_special_tokens': True,
        'stopping_strings': []
    }
@ -64,7 +75,7 @@ def run(user_input, history):
        result = response.json()['results'][0]['history']
        print(json.dumps(result, indent=4))
        print()
-        print(result['visible'][-1][1])
+        print(html.unescape(result['visible'][-1][1]))
if __name__ == '__main__':
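The two new request fields `guidance_scale` and `negative_prompt` enable classifier-free guidance (CFG). A minimal sketch of how a caller might opt in; the example values and prompt text are arbitrary, not defaults from the project:

```python
# Sketch: the two new fields that enable classifier-free guidance (CFG).
# guidance_scale == 1 leaves generation unchanged; values > 1 steer output away from negative_prompt.
cfg_fields = {
    'guidance_scale': 1.5,                      # arbitrary example value
    'negative_prompt': 'Short, lazy answers.',  # arbitrary example text
}
request = {'user_input': 'Explain RoPE scaling briefly.', 'max_new_tokens': 250, **cfg_fields}
print(request)
```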


@ -4,6 +4,7 @@ import requests
HOST = '0.0.0.0:5000'
def generate(prompt, tokens=200):
    request = {'prompt': prompt, 'max_new_tokens': tokens}
    response = requests.post(f'http://{HOST}/api/v1/generate', json=request)
@ -54,7 +55,7 @@ def complex_model_load(model):
        'action': 'load',
        'model_name': model,
        'args': {
-            'gptq_for_llama': False,  # Use AutoGPTQ by default, set to True for gptq-for-llama
+            'loader': 'AutoGPTQ',
            'bf16': False,
            'load_in_8bit': False,
@ -107,7 +108,7 @@ def complex_model_load(model):
        req['args']['bf16'] = True  # for 24GB
    elif '13b' in model:
        req['args']['load_in_8bit'] = True  # for 24GB
-    elif 'ggml' in model:
+    elif 'gguf' in model:
        # req['args']['threads'] = 16
        if '7b' in model:
            req['args']['n_gpu_layers'] = 100
@ -124,7 +125,6 @@ def complex_model_load(model):
    else:
        req['args']['rwkv_strategy'] = 'cuda f16'  # 24GB
    return model_api(req)
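For reference, a minimal load request built around the new `'loader'` field might look like the sketch below. The `/api/v1/model` endpoint and the placeholder model name are assumptions based on the rest of this example script and the README, not values taken from this diff:

```python
import requests

HOST = '0.0.0.0:5000'  # same host/port used elsewhere in this script

# Hypothetical payload: 'action', 'model_name', 'args', and 'loader' mirror the fields shown above.
payload = {
    'action': 'load',
    'model_name': 'your-gptq-model-folder',  # placeholder, any folder under models/
    'args': {
        'loader': 'AutoGPTQ',  # the loader string used in the diff above
        'load_in_8bit': False,
    },
}
response = requests.post(f'http://{HOST}/api/v1/model', json=payload)  # endpoint path is an assumption
print(response.json())
```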


@ -20,6 +20,8 @@ async def run(context):
    request = {
        'prompt': context,
        'max_new_tokens': 250,
+        'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
        # Generation params. If 'preset' is set to different than 'None', the values
        # in presets/preset-name.yaml are used instead of the individual numbers.
@ -44,11 +46,14 @@ async def run(context):
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
+        'custom_token_bans': '',
        'skip_special_tokens': True,
        'stopping_strings': []
    }


@ -12,6 +12,8 @@ def run(prompt):
    request = {
        'prompt': prompt,
        'max_new_tokens': 250,
+        'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
        # Generation params. If 'preset' is set to different than 'None', the values
        # in presets/preset-name.yaml are used instead of the individual numbers.
@ -36,11 +38,14 @@ def run(prompt):
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
+        'custom_token_bans': '',
        'skip_special_tokens': True,
        'stopping_strings': []
    }


@ -0,0 +1,4 @@
name: AI
greeting: How can I help you today?
context: |
The following is a conversation with an AI Large Language Model. The AI has been trained to answer questions, provide recommendations, and help with decision making. The AI follows user requests. The AI thinks outside the box.


@ -1,9 +1,10 @@
-name: "Chiharu Yamada"
+name: Chiharu Yamada
-context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology."
greeting: |-
  *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
  Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
-example_dialogue: |-
+context: |-
+  Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
  {{user}}: So how did you get into computer engineering?
  {{char}}: I've always loved tinkering with technology since I was a kid.
  {{user}}: That's really impressive!


@ -1,4 +0,0 @@
user: ""
bot: "### Response:"
turn_template: "<|user-message|>\n\n<|bot|><|bot-message|>\n\n</s>"
context: ""


@ -1,63 +0,0 @@
'''
Converts a transformers model to a format compatible with flexgen.
'''
import argparse
import os
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args()
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
global torch_linear_init_backup
global torch_layer_norm_init_backup
torch_linear_init_backup = torch.nn.Linear.reset_parameters
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init():
"""Rollback the change made by disable_torch_init."""
import torch
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
if __name__ == '__main__':
path = Path(args.MODEL)
model_name = path.name
print(f"Loading {model_name}...")
# disable_torch_init()
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
# restore_torch_init()
tokenizer = AutoTokenizer.from_pretrained(path)
out_folder = Path(f"models/{model_name}-np")
if not Path(out_folder).exists():
os.mkdir(out_folder)
print(f"Saving the converted model to {out_folder}...")
for name, param in tqdm(list(model.model.named_parameters())):
name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
param_path = os.path.join(out_folder, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())

36 binary files changed (not shown).

css/NotoSans/stylesheet.css (new file, 166 lines)

@ -0,0 +1,166 @@
/*
Copied from https://github.com/SillyTavern/SillyTavern/tree/6c8bd06308c69d51e2eb174541792a870a83d2d6/public/webfonts/NotoSans
*/
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-Black.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-Black.woff') format('woff');
font-weight: 900;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-ExtraBoldItalic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-ExtraBoldItalic.woff') format('woff');
font-weight: bold;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-BlackItalic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-BlackItalic.woff') format('woff');
font-weight: 900;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-ExtraBold.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-ExtraBold.woff') format('woff');
font-weight: bold;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-ThinItalic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-ThinItalic.woff') format('woff');
font-weight: 100;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-BoldItalic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-BoldItalic.woff') format('woff');
font-weight: bold;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-Bold.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-Bold.woff') format('woff');
font-weight: bold;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-LightItalic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-LightItalic.woff') format('woff');
font-weight: 300;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-Italic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-Italic.woff') format('woff');
font-weight: normal;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-ExtraLightItalic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-ExtraLightItalic.woff') format('woff');
font-weight: 200;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-Light.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-Light.woff') format('woff');
font-weight: 300;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-ExtraLight.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-ExtraLight.woff') format('woff');
font-weight: 200;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-Medium.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-Medium.woff') format('woff');
font-weight: 500;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-Regular.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-Regular.woff') format('woff');
font-weight: normal;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-MediumItalic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-MediumItalic.woff') format('woff');
font-weight: 500;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-SemiBoldItalic.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-SemiBoldItalic.woff') format('woff');
font-weight: 600;
font-style: italic;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-SemiBold.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-SemiBold.woff') format('woff');
font-weight: 600;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'Noto Sans';
src: url('file/css/NotoSans/NotoSans-Thin.woff2') format('woff2'),
url('file/css/NotoSans/NotoSans-Thin.woff') format('woff');
font-weight: 100;
font-style: normal;
font-display: swap;
}


@ -1,126 +0,0 @@
.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
height: 66.67vh
}
.gradio-container {
margin-left: auto !important;
margin-right: auto !important;
}
.w-screen {
width: unset
}
div.svelte-362y77>*, div.svelte-362y77>.form>* {
flex-wrap: nowrap
}
/* fixes the API documentation in chat mode */
.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h {
display: grid;
}
.pending.svelte-1ed2p3z {
opacity: 1;
}
#extensions {
padding: 0;
padding: 0;
}
#gradio-chatbot {
height: 66.67vh;
}
.wrap.svelte-6roggh.svelte-6roggh {
max-height: 92.5%;
}
/* This is for the microphone button in the whisper extension */
.sm.svelte-1ipelgc {
width: 100%;
}
#main button {
min-width: 0 !important;
}
/*****************************************************/
/*************** Chat box declarations ***************/
/*****************************************************/
.chat {
margin-left: auto;
margin-right: auto;
max-width: 800px;
height: calc(100vh - 296px);
overflow-y: auto;
padding-right: 20px;
display: flex;
flex-direction: column-reverse;
word-break: break-word;
overflow-wrap: anywhere;
padding-top: 1px;
}
.message-body li {
margin-top: 0.5em !important;
margin-bottom: 0.5em !important;
}
.message-body li > p {
display: inline !important;
}
.message-body ul, .message-body ol {
font-size: 15px !important;
}
.message-body ul {
list-style-type: disc !important;
}
.message-body pre {
margin-bottom: 1.25em !important;
}
.message-body code {
white-space: pre-wrap !important;
word-wrap: break-word !important;
}
.message-body :not(pre) > code {
white-space: normal !important;
}
@media print {
body {
visibility: hidden;
}
.chat {
visibility: visible;
position: absolute;
left: 0;
top: 0;
max-width: none;
max-height: none;
width: 100%;
height: fit-content;
display: flex;
flex-direction: column-reverse;
}
.message {
break-inside: avoid;
}
.gradio-container {
overflow: visible;
}
.tab-nav {
display: none !important;
}
}


@ -1,4 +0,0 @@
document.getElementById("main").childNodes[0].style = "max-width: 800px; margin-left: auto; margin-right: auto";
document.getElementById("extensions").style.setProperty("max-width", "800px");
document.getElementById("extensions").style.setProperty("margin-left", "auto");
document.getElementById("extensions").style.setProperty("margin-right", "auto");


@ -5,22 +5,14 @@
  grid-template-columns: 60px minmax(0, 1fr);
  padding-bottom: 28px;
  font-size: 18px;
-  /*Change 'Quicksand' to a font you like or leave it*/
-  font-family: Quicksand, Arial, sans-serif;
+  font-family: 'Noto Sans', Arial, sans-serif;
  line-height: 1.428571429;
}
-.circle-you {
-  background-color: gray;
-  border-radius: 1rem;
-  /*Change color to any you like to be the border of your image*/
-  border: 2px solid white;
-}
+.circle-you,
.circle-bot {
  background-color: gray;
  border-radius: 1rem;
-  /*Change color to any you like to be the border of the bot's image*/
  border: 2px solid white;
}
@ -41,7 +33,7 @@
.text {
  /*Change this to move the message box further left or right depending on the size of your profile pic*/
  padding-left: 90px;
-  text-shadow: 2px 2px 2px rgb(0, 0, 0);
+  text-shadow: 2px 2px 2px rgb(0, 0, 0, 0.4);
}
.text p {
@ -96,12 +88,46 @@
  margin-bottom: 0 !important;
  font-size: 18px !important;
  line-height: 1.428571429 !important;
+  color: rgb(243, 244, 246) !important;
+  text-shadow: 2px 2px 2px rgb(0, 0, 0);
}
-.dark .message-body p em {
-  color: rgb(138, 138, 138) !important;
-}
.message-body p em {
-  color: rgb(110, 110, 110) !important;
+  color: rgb(138, 138, 138) !important;
}
+@media screen and (max-width: 688px) {
+  .message {
+    display: grid;
+    grid-template-columns: 60px minmax(0, 1fr);
+    padding-bottom: 25px;
+    font-size: 15px;
+    font-family: 'Noto Sans', Helvetica, Arial, sans-serif;
+    line-height: 1.428571429;
+  }
+  .circle-you, .circle-bot {
+    width: 50px;
+    height: 73px;
+    border-radius: 0.5rem;
+  }
+  .circle-bot img,
+  .circle-you img {
+    width: 100%;
+    height: 100%;
+    object-fit: cover;
+  }
+  .text {
+    padding-left: 0px;
+  }
+  .message-body p {
+    font-size: 16px !important;
+  }
+  .username {
+    font-size: 20px;
+  }
+}


@ -0,0 +1,21 @@
@import url("file/css/chat_style-cai-chat.css");
.circle-bot, .circle-you {
height: 90px;
width: 60px;
border-radius: 10px;
background-color: #656565;
}
.circle-bot img, .circle-you img {
border-radius: 8.333px;
}
.circle-you {
background-color: #656565;
}
.message {
padding-bottom: 30px;
grid-template-columns: 70px minmax(0, 1fr);
}


@ -3,8 +3,8 @@
  grid-template-columns: 60px minmax(0, 1fr);
  padding-bottom: 25px;
  font-size: 15px;
-  font-family: Helvetica, Arial, sans-serif;
+  font-family: 'Noto Sans', Helvetica, Arial, sans-serif;
-  line-height: 1.428571429;
+  line-height: 23px !important;
}
.circle-you {
@ -46,7 +46,7 @@
.message-body p {
  margin-bottom: 0 !important;
  font-size: 15px !important;
-  line-height: 1.428571429 !important;
+  line-height: 23px !important;
}
.dark .message-body p em {
@ -55,4 +55,5 @@
.message-body p em {
  color: rgb(110, 110, 110) !important;
+  font-weight: 500;
}


@ -1,7 +1,7 @@
.message {
  padding-bottom: 25px;
  font-size: 15px;
-  font-family: Helvetica, Arial, sans-serif;
+  font-family: 'Noto Sans', Helvetica, Arial, sans-serif;
  line-height: 1.428571429;
}


@ -1,7 +1,7 @@
.message {
  padding-bottom: 25px;
  font-size: 15px;
-  font-family: Helvetica, Arial, sans-serif;
+  font-family: 'Noto Sans', Helvetica, Arial, sans-serif;
  line-height: 1.428571429;
}


@ -98,7 +98,7 @@
  margin-right: 40px !important;
}
-#parent #container .message {
+#parent #container .message_4chan {
  color: black;
  border: none;
}


@ -3,8 +3,8 @@
  grid-template-columns: 60px 1fr;
  padding-bottom: 25px;
  font-size: 15px;
-  font-family: Helvetica, Arial, sans-serif;
+  font-family: 'Noto Sans', Helvetica, Arial, sans-serif;
-  line-height: 1.428571429;
+  line-height: 22px;
}
.username {
@ -13,11 +13,11 @@
.message-body p {
  font-size: 15px !important;
-  line-height: 1.75 !important;
+  line-height: 22px !important;
  margin-bottom: 1.25em !important;
}
-.message-body ul, .message-body ol {
+.chat .message-body ul, .chat .message-body ol {
  margin-bottom: 1.25em !important;
}
@ -43,14 +43,16 @@
  margin-bottom: 9px !important;
}
+.gradio-container .chat .assistant-message:last-child, .gradio-container .chat .user-message:last-child {
+  margin-bottom: 0px !important;
+}
.dark .chat .assistant-message {
-  background-color: #3741519e;
+  background-color: #1f2937;
+  border: 1px solid #4b5563;
}
.dark .chat .user-message {
-  background-color: #111827;
+  background-color: transparent;
+  border: 1px solid #4b5563;
}
code {
@ -58,5 +60,5 @@ code {
}
.dark code {
-  background-color: #1a212f !important;
+  background-color: #0e1321 !important;
}


@ -27,3 +27,7 @@
.container :not(pre) > code {
  white-space: normal !important;
}
+.container .hoverable {
+  font-size: 14px;
+}


@ -7,6 +7,7 @@
} }
.small-button { .small-button {
min-width: 0 !important;
max-width: 171px; max-width: 171px;
height: 39.594px; height: 39.594px;
align-self: end; align-self: end;
@ -26,6 +27,10 @@
max-width: 2.2em; max-width: 2.2em;
} }
.button_nowrap {
white-space: nowrap;
}
#slim-column { #slim-column {
flex: none !important; flex: none !important;
min-width: 0 !important; min-width: 0 !important;
@ -41,9 +46,6 @@
min-height: 0 min-height: 0
} }
#accordion {
}
.dark svg { .dark svg {
fill: white; fill: white;
} }
@ -56,7 +58,7 @@ ol li p, ul li p {
display: inline-block; display: inline-block;
} }
#main, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab { #chat-tab, #default-tab, #notebook-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab {
border: 0; border: 0;
} }
@ -70,7 +72,7 @@ ol li p, ul li p {
} }
#extensions { #extensions {
padding: 15px; margin-top: 5px;
margin-bottom: 35px; margin-bottom: 35px;
} }
@ -89,7 +91,11 @@ div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
.header_bar { .header_bar {
background-color: #f7f7f7; background-color: #f7f7f7;
margin-bottom: 30px; margin-bottom: 19px;
display: inline !important;
overflow-x: scroll;
margin-left: calc(-1 * var(--size-4));
margin-right: calc(-1 * var(--size-4));
} }
.dark .header_bar { .dark .header_bar {
@ -97,19 +103,39 @@ div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
background-color: #8080802b; background-color: #8080802b;
} }
.header_bar button.selected {
border-radius: 0;
}
.textbox_default textarea { .textbox_default textarea {
height: calc(100vh - 390px); height: calc(100dvh - 271px);
} }
.textbox_default_output textarea { .textbox_default_output textarea {
height: calc(100vh - 200px); height: calc(100dvh - 185px);
} }
.textbox textarea { .textbox textarea {
height: calc(100vh - 251px); height: calc(100dvh - 241px);
} }
.textbox_default textarea, .textbox_default_output textarea, .textbox textarea { .textbox_logits textarea {
height: calc(100dvh - 236px);
}
.textbox_logits_notebook textarea {
height: calc(100dvh - 292px);
}
.monospace {
font-family: monospace;
}
.textbox_default textarea,
.textbox_default_output textarea,
.textbox_logits textarea,
.textbox_logits_notebook textarea,
.textbox textarea {
font-size: 16px !important; font-size: 16px !important;
color: #46464A !important; color: #46464A !important;
} }
@ -118,6 +144,24 @@ div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
color: #efefef !important; color: #efefef !important;
} }
@media screen and (max-width: 711px) {
.textbox_default textarea {
height: calc(100dvh - 259px);
}
div .default-token-counter {
top: calc( 0.5 * (100dvh - 236px) ) !important;
}
.transparent-substring {
display: none;
}
.hover-menu {
min-width: 250px !important;
}
}
/* Hide the gradio footer */ /* Hide the gradio footer */
footer { footer {
display: none !important; display: none !important;
@ -155,3 +199,405 @@ button {
.markdown ul ol { .markdown ul ol {
font-size: 100% !important; font-size: 100% !important;
} }
.pretty_scrollbar::-webkit-scrollbar {
width: 5px;
}
.pretty_scrollbar::-webkit-scrollbar-track {
background: transparent;
}
.pretty_scrollbar::-webkit-scrollbar-thumb,
.pretty_scrollbar::-webkit-scrollbar-thumb:hover {
background: #c5c5d2;
}
.dark .pretty_scrollbar::-webkit-scrollbar-thumb,
.dark .pretty_scrollbar::-webkit-scrollbar-thumb:hover {
background: #374151;
}
.pretty_scrollbar::-webkit-resizer {
background: #c5c5d2;
}
.dark .pretty_scrollbar::-webkit-resizer {
background: #374151;
}
audio {
max-width: 100%;
}
/* Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui */
.token-counter {
position: absolute !important;
top: calc( 0.5 * (100dvh - 218px) ) !important;
right: 2px;
z-index: 100;
background: var(--input-background-fill) !important;
min-height: 0 !important;
}
.default-token-counter {
top: calc( 0.5 * (100dvh - 248px) ) !important;
}
.token-counter span {
padding: 1px;
box-shadow: 0 0 0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075);
border: 2px solid rgba(192,192,192,0.4) !important;
border-radius: 0.4em;
}
.no-background {
background: var(--background-fill-primary) !important;
padding: 0px !important;
}
/*****************************************************/
/*************** Chat UI declarations ****************/
/*****************************************************/
.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
height: 66.67vh
}
.gradio-container {
margin-left: auto !important;
margin-right: auto !important;
}
.w-screen {
width: unset
}
div.svelte-362y77>*, div.svelte-362y77>.form>* {
flex-wrap: nowrap
}
.pending.svelte-1ed2p3z {
opacity: 1;
}
.wrap.svelte-6roggh.svelte-6roggh {
max-height: 92.5%;
}
/* This is for the microphone button in the whisper extension */
.sm.svelte-1ipelgc {
width: 100%;
}
#chat-tab button#Generate, #chat-tab button#stop {
width: 89.3438px !important;
}
#chat-tab button, #notebook-tab button, #default-tab button {
min-width: 0 !important;
}
#chat-tab > :first-child, #extensions {
max-width: 880px;
margin-left: auto;
margin-right: auto;
}
@media screen and (max-width: 688px) {
#chat-tab {
padding-left: 0px;
padding-right: 0px;
}
.chat-parent {
height: calc(100dvh - 179px) !important;
}
.old-ui .chat-parent {
height: calc(100dvh - 310px) !important;
}
}
.chat {
margin-left: auto;
margin-right: auto;
max-width: 880px;
height: 100%;
overflow-y: auto;
padding-right: 15px;
display: flex;
flex-direction: column;
word-break: break-word;
overflow-wrap: anywhere;
}
.chat-parent {
height: calc(100dvh - 181px);
overflow: auto !important;
}
.old-ui .chat-parent {
height: calc(100dvh - 270px);
}
.chat-parent.bigchat {
height: calc(100dvh - 181px) !important;
}
.chat > .messages {
display: flex;
flex-direction: column;
}
.chat .message:last-child {
margin-bottom: 0px !important;
padding-bottom: 0px !important;
}
.message-body li {
margin-top: 0 !important;
margin-bottom: 0 !important;
}
.message-body li > p {
display: inline !important;
}
.message-body ul, .message-body ol {
font-size: 15px !important;
}
.message-body ul {
list-style-type: disc !important;
}
.message-body pre {
margin-bottom: 1.25em !important;
}
.message-body code {
white-space: pre-wrap !important;
word-wrap: break-word !important;
}
.message-body :not(pre) > code {
white-space: normal !important;
}
#chat-input {
padding: 0;
padding-top: 18px;
background: transparent;
border: none;
}
#chat-input textarea:focus {
box-shadow: none !important;
}
@media print {
body {
visibility: hidden;
}
.chat {
visibility: visible;
position: absolute;
left: 0;
top: 0;
max-width: unset;
max-height: unset;
width: 100%;
overflow-y: visible;
}
.message {
break-inside: avoid;
}
.gradio-container {
overflow: visible;
}
.tab-nav {
display: none !important;
}
#chat-tab > :first-child {
max-width: unset;
}
}
#show-controls {
position: absolute;
height: 100%;
background-color: var(--background-fill-primary);
border: 0px;
border-radius: 0px;
}
#show-controls label {
z-index: 1000;
position: absolute;
left: calc(100% - 168px);
}
#typing-container {
display: none;
position: absolute;
background-color: transparent;
left: -2px;
padding: var(--block-padding);
}
.typing {
position: relative;
}
.visible-dots #typing-container {
display: block;
}
.typing span {
content: '';
animation: blink 1.5s infinite;
animation-fill-mode: both;
height: 10px;
width: 10px;
background: #3b5998;
position: absolute;
left: 0;
top: 0;
border-radius: 50%;
}
.typing .dot1 {
animation-delay: .2s;
margin-left: calc(10px * 1.5);
}
.typing .dot2 {
animation-delay: .4s;
margin-left: calc(10px * 3);
}
@keyframes blink {
0% {
opacity: .1;
}
20% {
opacity: 1;
}
100% {
opacity: .1;
}
}
#chat-tab .generating {
display: none !important;
}
.hover-element {
position: relative;
font-size: 24px;
}
.hover-menu {
display: none;
position: absolute;
bottom: 80%;
left: 0;
background-color: var(--background-fill-secondary);
box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
z-index: 10000;
min-width: 330px;
flex-direction: column;
}
.hover-menu button {
width: 100%;
background: transparent !important;
border-radius: 0px !important;
justify-content: space-between;
margin: 0 !important;
height: 36px;
}
.hover-menu button:not(#clear-history-confirm) {
border-bottom: 0 !important;
}
.hover-menu button:not(#clear-history-confirm):last-child {
border-bottom: var(--button-border-width) solid var(--button-secondary-border-color) !important;
}
.hover-menu button:hover {
background: var(--button-secondary-background-fill-hover) !important;
}
.transparent-substring {
opacity: 0.333;
}
#chat-tab:not(.old-ui) #chat-buttons {
display: none !important;
}
#gr-hover-container {
min-width: 0 !important;
display: flex;
flex-direction: column-reverse;
padding-right: 20px;
padding-bottom: 3px;
flex-grow: 0 !important;
}
#generate-stop-container {
min-width: 0 !important;
display: flex;
flex-direction: column-reverse;
padding-bottom: 3px;
flex: 0 auto !important;
}
#chat-input-container {
min-width: 0 !important;
}
#chat-input-container > .form {
background: transparent;
border: none;
}
#chat-input-row {
padding-bottom: 20px;
}
.old-ui #chat-input-row, #chat-input-row.bigchat {
padding-bottom: 0px !important;
}
#chat-col {
padding-bottom: 115px;
}
.old-ui #chat-col, #chat-col.bigchat {
padding-bottom: 95px !important;
}
.old-ui #chat-buttons #clear-history-confirm {
order: -1;
}
.chat ol, .chat ul {
margin-top: 6px !important;
}
#past-chats-row {
margin-bottom: calc( -1 * var(--layout-gap) );
}
#rename-row label {
margin-top: var(--layout-gap);
}

View file

@ -1,18 +0,0 @@
document.getElementById("main").parentNode.childNodes[0].classList.add("header_bar");
document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
// Get references to the elements
let main = document.getElementById('main');
let main_parent = main.parentNode;
let extensions = document.getElementById('extensions');
// Add an event listener to the main element
main_parent.addEventListener('click', function(e) {
// Check if the main element is visible
if (main.offsetHeight > 0 && main.offsetWidth > 0) {
extensions.style.display = 'flex';
} else {
extensions.style.display = 'none';
}
});

View file

@ -1,22 +1,23 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as builder FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as builder
RUN apt-get update && \ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw apt-get update && \
apt-get install --no-install-recommends -y git vim build-essential python3-dev python3-venv && \ apt-get install --no-install-recommends -y git vim build-essential python3-dev python3-venv && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN git clone https://github.com/oobabooga/GPTQ-for-LLaMa /build RUN git clone --depth=1 https://github.com/oobabooga/GPTQ-for-LLaMa /build
WORKDIR /build WORKDIR /build
RUN python3 -m venv /build/venv RUN --mount=type=cache,target=/root/.cache/pip,rw \
RUN . /build/venv/bin/activate && \ python3 -m venv /build/venv && \
. /build/venv/bin/activate && \
pip3 install --upgrade pip setuptools wheel && \ pip3 install --upgrade pip setuptools wheel && \
pip3 install torch torchvision torchaudio && \ pip3 install torch torchvision torchaudio && \
pip3 install -r requirements.txt pip3 install -r requirements.txt
# https://developer.nvidia.com/cuda-gpus # https://developer.nvidia.com/cuda-gpus
# for a rtx 2060: ARG TORCH_CUDA_ARCH_LIST="7.5" # for a rtx 2060: ARG TORCH_CUDA_ARCH_LIST="7.5"
ARG TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}"
RUN . /build/venv/bin/activate && \ RUN . /build/venv/bin/activate && \
python3 setup_cuda.py bdist_wheel -d . python3 setup_cuda.py bdist_wheel -d .
@ -25,11 +26,11 @@ FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
LABEL maintainer="Your Name <your.email@example.com>" LABEL maintainer="Your Name <your.email@example.com>"
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI" LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
RUN apt-get update && \ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw apt-get update && \
apt-get install --no-install-recommends -y python3-dev libportaudio2 libasound-dev git python3 python3-pip make g++ && \ apt-get install --no-install-recommends -y python3-dev libportaudio2 libasound-dev git python3 python3-pip make g++ ffmpeg && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv RUN --mount=type=cache,target=/root/.cache/pip,rw pip3 install virtualenv
RUN mkdir /app RUN mkdir /app
WORKDIR /app WORKDIR /app
@ -37,32 +38,38 @@ WORKDIR /app
ARG WEBUI_VERSION ARG WEBUI_VERSION
RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Using provided webui source" RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Using provided webui source"
# Create virtualenv
RUN virtualenv /app/venv RUN virtualenv /app/venv
RUN . /app/venv/bin/activate && \ RUN --mount=type=cache,target=/root/.cache/pip,rw \
. /app/venv/bin/activate && \
pip3 install --upgrade pip setuptools wheel && \ pip3 install --upgrade pip setuptools wheel && \
pip3 install torch torchvision torchaudio pip3 install torch torchvision torchaudio sentence_transformers xformers
# Copy and install GPTQ-for-LLaMa
COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa
RUN . /app/venv/bin/activate && \ RUN --mount=type=cache,target=/root/.cache/pip,rw \
. /app/venv/bin/activate && \
pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl
COPY extensions/api/requirements.txt /app/extensions/api/requirements.txt # Install main requirements
COPY extensions/elevenlabs_tts/requirements.txt /app/extensions/elevenlabs_tts/requirements.txt
COPY extensions/google_translate/requirements.txt /app/extensions/google_translate/requirements.txt
COPY extensions/silero_tts/requirements.txt /app/extensions/silero_tts/requirements.txt
COPY extensions/whisper_stt/requirements.txt /app/extensions/whisper_stt/requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt
COPY requirements.txt /app/requirements.txt COPY requirements.txt /app/requirements.txt
RUN . /app/venv/bin/activate && \ RUN --mount=type=cache,target=/root/.cache/pip,rw \
. /app/venv/bin/activate && \
pip3 install -r requirements.txt pip3 install -r requirements.txt
COPY . /app/
RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
COPY . /app/ # Install extension requirements
RUN --mount=type=cache,target=/root/.cache/pip,rw \
. /app/venv/bin/activate && \
for ext in /app/extensions/*/requirements.txt; do \
cd "$(dirname "$ext")"; \
pip3 install -r requirements.txt; \
done
ENV CLI_ARGS="" ENV CLI_ARGS=""
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS} CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS}

View file

@ -5,13 +5,13 @@ services:
context: . context: .
args: args:
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus # specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST} TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
WEBUI_VERSION: ${WEBUI_VERSION} WEBUI_VERSION: ${WEBUI_VERSION:-HEAD}
env_file: .env env_file: .env
ports: ports:
- "${HOST_PORT}:${CONTAINER_PORT}" - "${HOST_PORT:-7860}:${CONTAINER_PORT:-7860}"
- "${HOST_API_PORT}:${CONTAINER_API_PORT}" - "${HOST_API_PORT:-5000}:${CONTAINER_API_PORT:-5000}"
- "${HOST_API_STREAM_PORT}:${CONTAINER_API_STREAM_PORT}" - "${HOST_API_STREAM_PORT:-5005}:${CONTAINER_API_STREAM_PORT:-5005}"
stdin_open: true stdin_open: true
tty: true tty: true
volumes: volumes:
@ -23,6 +23,7 @@ services:
- ./prompts:/app/prompts - ./prompts:/app/prompts
- ./softprompts:/app/softprompts - ./softprompts:/app/softprompts
- ./training:/app/training - ./training:/app/training
- ./cloudflared:/etc/cloudflared
deploy: deploy:
resources: resources:
reservations: reservations:

View file

@ -1,36 +1,30 @@
## Chat characters ## Chat characters
Custom chat mode characters are defined by `.yaml` files inside the `characters` folder. An example is included: [Example.yaml](https://github.com/oobabooga/text-generation-webui/blob/main/characters/Example.yaml) Custom chat mode characters are defined by `.yaml` files inside the `characters` folder. An example is included: [Example.yaml](https://github.com/oobabooga/text-generation-webui/blob/main/characters/Example.yaml).
The following fields may be defined: The following fields may be defined:
| Field | Description | | Field | Description |
|-------|-------------| |-------|-------------|
| `name` or `bot` | The character's name. | | `name` or `bot` | The character's name. |
| `context` | A string that appears at the top of the prompt. It usually contains a description of the character's personality and a few example messages. |
| `greeting` (optional) | The character's opening message. It appears when the character is first loaded or when the history is cleared. |
| `your_name` or `user` (optional) | Your name. This overwrites what you had previously written in the `Your name` field in the interface. | | `your_name` or `user` (optional) | Your name. This overwrites what you had previously written in the `Your name` field in the interface. |
| `context` | A string that appears at the top of the prompt. It usually contains a description of the character's personality. |
| `greeting` (optional) | The character's opening message when a new conversation is started. |
| `example_dialogue` (optional) | A few example messages to guide the model. |
| `turn_template` (optional) | Used to define where the spaces and new line characters should be in Instruct mode. See the characters in `characters/instruction-following` for examples. |
#### Special tokens #### Special tokens
* `{{char}}` or `<BOT>`: are replaced with the character's name The following replacements happen when the prompt is generated, and they apply to the `context` and `greeting` fields:
* `{{user}}` or `<USER>`: are replaced with your name
These replacements happen when the character is loaded, and they apply to the `context`, `greeting`, and `example_dialogue` fields. * `{{char}}` and `<BOT>` get replaced with the character's name.
* `{{user}}` and `<USER>` get replaced with your name.
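
As an illustration, the substitution described above amounts to a plain string replacement applied to those fields. A minimal sketch (not the webui's actual code; the function name is just for this example):

```python
def apply_character_tokens(text, char_name, user_name):
    # {{char}}/<BOT> -> character's name, {{user}}/<USER> -> your name
    for token in ("{{char}}", "<BOT>"):
        text = text.replace(token, char_name)
    for token in ("{{user}}", "<USER>"):
        text = text.replace(token, user_name)
    return text

print(apply_character_tokens("{{user}}: hi\n{{char}}: hello!", "Example", "You"))
```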
#### How do I add a profile picture for my character? #### How do I add a profile picture for my character?
Put an image with the same name as your character's yaml file into the `characters` folder. For example, if your bot is `Character.yaml`, add `Character.jpg` or `Character.png` to the folder. Put an image with the same name as your character's `.yaml` file into the `characters` folder. For example, if your bot is `Character.yaml`, add `Character.jpg` or `Character.png` to the folder.
#### Is the chat history truncated in the prompt? #### Is the chat history truncated in the prompt?
Once your prompt reaches the 2048 token limit, old messages will be removed one at a time. The context string will always stay at the top of the prompt and will never get truncated. Once your prompt reaches the `truncation_length` parameter (2048 by default), old messages will be removed one at a time. The context string will always stay at the top of the prompt and will never get truncated.
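
A minimal sketch of that truncation logic, assuming a hypothetical `count_tokens` helper passed in by the caller (this is an illustration, not the webui's actual implementation):

```python
def truncate_history(context, history, truncation_length, count_tokens):
    """Drop the oldest messages until the prompt fits; the context string is never removed."""
    messages = list(history)  # oldest message first

    def prompt(msgs):
        return context + "\n".join(msgs)

    while messages and count_tokens(prompt(messages)) > truncation_length:
        messages.pop(0)  # remove the oldest message first

    return prompt(messages)
```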
#### Pygmalion format characters
These are also supported out of the box. Simply put the JSON file in the `characters` folder, or upload it directly from the web UI by clicking on the "Upload character" tab at the bottom.
## Chat styles ## Chat styles

View file

@ -1,45 +1,47 @@
Extensions are defined by files named `script.py` inside subfolders of `text-generation-webui/extensions`. They are loaded at startup if specified with the `--extensions` flag. # Extensions
Extensions are defined by files named `script.py` inside subfolders of `text-generation-webui/extensions`. They are loaded at startup if the folder name is specified after the `--extensions` flag.
For instance, `extensions/silero_tts/script.py` gets loaded with `python server.py --extensions silero_tts`. For instance, `extensions/silero_tts/script.py` gets loaded with `python server.py --extensions silero_tts`.
## [text-generation-webui-extensions](https://github.com/oobabooga/text-generation-webui-extensions) ## [text-generation-webui-extensions](https://github.com/oobabooga/text-generation-webui-extensions)
The link above contains a directory of user extensions for text-generation-webui. The repository above contains a directory of user extensions.
If you create an extension, you are welcome to host it in a GitHub repository and submit it to the list above. If you create an extension, you are welcome to host it in a GitHub repository and submit a PR adding it to the list.
## Built-in extensions ## Built-in extensions
Most of these have been created by the extremely talented contributors that you can find here: [contributors](https://github.com/oobabooga/text-generation-webui/graphs/contributors?from=2022-12-18&to=&type=a).
|Extension|Description| |Extension|Description|
|---------|-----------| |---------|-----------|
|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` port 5000. This is the main API for this web UI. | |[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` port 5000. This is the main API for the webui. |
|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. |
|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. |
|[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.| |[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.|
|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that biases the bot's responses in chat mode.| |[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, responses are replaced with an audio widget. |
|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. |
|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, it replaces the responses with an audio widget. |
|[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. | |[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. |
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. | |[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). | |[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. | |[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that adds a hidden string at the beginning of the bot's reply in chat mode. |
|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. | |[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. |
|[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. | |[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. |
|[ngrok](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/ngrok)| Allows you to access the web UI remotely using the ngrok reverse tunnel service (free). It's an alternative to the built-in Gradio `--share` feature. |
|[perplexity_colors](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/perplexity_colors)| Colors each token in the output text by its associated probability, as derived from the model logits. |
## How to write an extension ## How to write an extension
script.py may define the special functions and variables below. The extensions framework is based on special functions and variables that you can define in `script.py`. The functions are the following:
#### Predefined functions
| Function | Description | | Function | Description |
|-------------|-------------| |-------------|-------------|
| `def setup()` | Is executed when the extension gets imported. |
| `def ui()` | Creates custom gradio elements when the UI is launched. | | `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. | | `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. |
| `def custom_js()` | Same as above but for javascript. | | `def custom_js()` | Same as above but for javascript. |
| `def input_modifier(string, state)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def input_modifier(string, state, is_chat=False)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
| `def output_modifier(string, state)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def output_modifier(string, state, is_chat=False)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def chat_input_modifier(text, visible_text, state)` | Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. |
| `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. | | `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. |
| `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. |
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
@ -48,9 +50,7 @@ script.py may define the special functions and variables below.
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. | | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. |
| `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `multimodal` extension for an example. | | `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `multimodal` extension for an example. |
#### `params` dictionary Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab.
In this dictionary, `display_name` is used to define the displayed name of the extension in the UI, and `is_tab` is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab.
Example: Example:
@ -61,7 +61,7 @@ params = {
} }
``` ```
Additionally, `params` may contain variables that you want to be customizable through a `settings.json` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` in The `params` dict may also contain variables that you want to be customizable through a `settings.yaml` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` in
```python ```python
params = { params = {
@ -71,32 +71,19 @@ params = {
} }
``` ```
can be customized by adding a key called `google_translate-language string` to `settings.json`: can be customized by adding a key called `google_translate-language string` to `settings.yaml`:
```python ```python
"google_translate-language string": "fr", google_translate-language string: 'fr'
``` ```
That is, the syntax is `extension_name-variable_name`. That is, the syntax for the key is `extension_name-variable_name`.
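
To make the naming scheme concrete, here is a small sketch of how such a key could be merged into an extension's `params` at startup, assuming PyYAML is installed (`apply_user_settings` is hypothetical and only illustrates the `extension_name-variable_name` convention, not the webui's actual loading code):

```python
import yaml

# simplified stand-in for the extension's params dict
params = {
    "language string": "en",
}

def apply_user_settings(extension_name, params, settings_path="settings.yaml"):
    """Override params with keys of the form '<extension_name>-<variable name>'."""
    with open(settings_path) as f:
        settings = yaml.safe_load(f) or {}

    prefix = f"{extension_name}-"
    for key, value in settings.items():
        if key.startswith(prefix):
            params[key[len(prefix):]] = value

apply_user_settings("google_translate", params)
print(params["language string"])  # 'fr' if settings.yaml contains: google_translate-language string: 'fr'
```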
#### `input_hijack` dictionary
```python
input_hijack = {
'state': False,
'value': ["", ""]
}
```
This is only used in chat mode. If your extension sets `input_hijack['state'] = True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example.
Additionally, your extension can set the value to be a callback in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `multimodal` extension above for an example.
## Using multiple extensions at the same time ## Using multiple extensions at the same time
In order to use your extension, you must start the web UI with the `--extensions` flag followed by the name of your extension (the folder under `text-generation-webui/extensions` where `script.py` resides). You can activate more than one extension at a time by providing their names separated by spaces after `--extensions`. The input, output, and bot prefix modifiers will be applied in the specified order.
You can activate more than one extension at a time by providing their names separated by spaces. The input, output, and bot prefix modifiers will be applied in the specified order.
Example:
``` ```
python server.py --extensions enthusiasm translate # First apply enthusiasm, then translate python server.py --extensions enthusiasm translate # First apply enthusiasm, then translate
@ -106,56 +93,152 @@ python server.py --extensions translate enthusiasm # First apply translate, then
Do note that, for: Do note that, for:
- `custom_generate_chat_prompt` - `custom_generate_chat_prompt`
- `custom_generate_reply` - `custom_generate_reply`
- `tokenizer_modifier`
- `custom_tokenized_length` - `custom_tokenized_length`
only the first declaration encountered will be used and the rest will be ignored. only the first declaration encountered will be used and the rest will be ignored.
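
To make the ordering concrete, below is a simplified sketch of how stacked `output_modifier` functions could be applied in the order the extensions were listed (the two toy extensions and the dispatch loop are illustrative only, not the webui's actual dispatch code):

```python
# Two toy extensions, listed in the order passed to --extensions
def enthusiasm_output_modifier(string, state, is_chat=False):
    return string + "!!!"

def translate_output_modifier(string, state, is_chat=False):
    return string.upper()  # stand-in for a real translation step

loaded_modifiers = [enthusiasm_output_modifier, translate_output_modifier]

def apply_output_modifiers(string, state):
    # Applied in the order the extensions were given on the command line
    for modifier in loaded_modifiers:
        string = modifier(string, state)
    return string

print(apply_output_modifiers("hello", state={}))  # prints "HELLO!!!"
```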
## The `bot_prefix_modifier` ## A full example
In chat mode, this function modifies the prefix for a new bot message. For instance, if your bot is named `Marie Antoinette`, the default prefix for a new message will be The source code below can be found at [extensions/example/script.py](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/example/script.py).
```
Marie Antoinette:
```
Using `bot_prefix_modifier`, you can change it to:
```
Marie Antoinette: *I am very enthusiastic*
```
Marie Antoinette will become very enthusiastic in all her messages.
## `custom_generate_reply` example
Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet.
Note that in chat mode, this function must only return the new text, whereas in other modes it must return the original prompt + the new text.
```python ```python
import datetime """
An example of extension. It does nothing, but you can add transformations
before the return statements to customize the webui behavior.
def custom_generate_reply(question, original_question, seed, state, stopping_strings): Starting from history_modifier and ending in output_modifier, the
cumulative = '' functions are declared in the same order that they are called at
for i in range(10): generation time.
cumulative += f"Counting: {i}...\n" """
yield cumulative
cumulative += f"Done! {str(datetime.datetime.now())}" import gradio as gr
yield cumulative import torch
``` from transformers import LogitsProcessor
## `custom_generate_chat_prompt` example from modules import chat, shared
from modules.text_generation import (
decode,
encode,
generate_reply,
)
Below is an extension that just reproduces the default prompt generator in `modules/chat.py`. You can modify it freely to come up with your own prompts in chat mode. params = {
"display_name": "Example Extension",
"is_tab": False,
}
```python class MyLogits(LogitsProcessor):
from modules import chat """
Manipulates the probabilities for the next token before it gets sampled.
Used in the logits_processor_modifier function below.
"""
def __init__(self):
pass
def __call__(self, input_ids, scores):
# probs = torch.softmax(scores, dim=-1, dtype=torch.float)
# probs[0] /= probs[0].sum()
# scores = torch.log(probs / (1 - probs))
return scores
def history_modifier(history):
"""
Modifies the chat history.
Only used in chat mode.
"""
return history
def state_modifier(state):
"""
Modifies the state variable, which is a dictionary containing the input
values in the UI like sliders and checkboxes.
"""
return state
def chat_input_modifier(text, visible_text, state):
"""
Modifies the user input string in chat mode (visible_text).
You can also modify the internal representation of the user
input (text) to change how it will appear in the prompt.
"""
return text, visible_text
def input_modifier(string, state, is_chat=False):
"""
In default/notebook modes, modifies the whole prompt.
In chat mode, it is the same as chat_input_modifier but only applied
to "text", here called "string", and not to "visible_text".
"""
return string
def bot_prefix_modifier(string, state):
"""
Modifies the prefix for the next bot reply in chat mode.
By default, the prefix will be something like "Bot Name:".
"""
return string
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
"""
Modifies the input ids and embeds.
Used by the multimodal extension to put image embeddings in the prompt.
Only used by loaders that use the transformers library for sampling.
"""
return prompt, input_ids, input_embeds
def logits_processor_modifier(processor_list, input_ids):
"""
Adds logits processors to the list, allowing you to access and modify
the next token probabilities.
Only used by loaders that use the transformers library for sampling.
"""
processor_list.append(MyLogits())
return processor_list
def output_modifier(string, state, is_chat=False):
"""
Modifies the LLM output before it gets presented.
In chat mode, the modified version goes into history['visible'],
and the original version goes into history['internal'].
"""
return string
def custom_generate_chat_prompt(user_input, state, **kwargs): def custom_generate_chat_prompt(user_input, state, **kwargs):
"""
Replaces the function that generates the prompt from the chat history.
Only used in chat mode.
"""
result = chat.generate_chat_prompt(user_input, state, **kwargs)
return result
# Do something with kwargs['history'] or state def custom_css():
"""
Returns a CSS string that gets appended to the CSS for the webui.
"""
return ''
return chat.generate_chat_prompt(user_input, state, **kwargs) def custom_js():
"""
Returns a javascript string that gets appended to the javascript
for the webui.
"""
return ''
def setup():
"""
Gets executed only once, when the extension is imported.
"""
pass
def ui():
"""
Gets executed when the UI is drawn. Custom gradio elements and
their corresponding event handlers should be defined here.
To learn about gradio components, check out the docs:
https://gradio.app/docs/
"""
pass
``` ```

View file

@ -1,64 +0,0 @@
> FlexGen is a high-throughput generation engine for running large language models with limited GPU memory (e.g., a 16GB T4 GPU or a 24GB RTX3090 gaming card!).
https://github.com/FMInference/FlexGen
## Installation
No additional installation steps are necessary. FlexGen is in the `requirements.txt` file for this project.
## Converting a model
FlexGen only works with the OPT model, and it needs to be converted to numpy format before starting the web UI:
```
python convert-to-flexgen.py models/opt-1.3b/
```
The output will be saved to `models/opt-1.3b-np/`.
## Usage
The basic command is the following:
```
python server.py --model opt-1.3b --loader flexgen
```
For large models, the RAM usage may be too high and your computer may freeze. If that happens, you can try this:
```
python server.py --model opt-1.3b --loader flexgen --compress-weight
```
With this second command, I was able to run both OPT-6.7b and OPT-13B with **2GB VRAM**, and the speed was good in both cases.
You can also manually set the offload strategy with
```
python server.py --model opt-1.3b --loader flexgen --percent 0 100 100 0 100 0
```
where the six numbers after `--percent` are:
```
the percentage of weight on GPU
the percentage of weight on CPU
the percentage of attention cache on GPU
the percentage of attention cache on CPU
the percentage of activations on GPU
the percentage of activations on CPU
```
You should typically only change the first two numbers. If their sum is less than 100, the remaining layers will be offloaded to the disk, by default into the `text-generation-webui/cache` folder.
## Performance
In my experiments with OPT-30B using a RTX 3090 on Linux, I have obtained these results:
* `--loader flexgen --compress-weight --percent 0 100 100 0 100 0`: 0.99 seconds per token.
* `--loader flexgen --compress-weight --percent 100 0 100 0 100 0`: 0.765 seconds per token.
## Limitations
* Only works with the OPT models.
* Only two generation parameters are available: `temperature` and `do_sample`.

View file

@ -64,59 +64,19 @@ python server.py --autogptq --gpu-memory 3000MiB 6000MiB --model model_name
### Using LoRAs with AutoGPTQ ### Using LoRAs with AutoGPTQ
Not supported yet. Works fine for a single LoRA.
## GPTQ-for-LLaMa ## GPTQ-for-LLaMa
GPTQ-for-LLaMa is the original adaptation of GPTQ for the LLaMA model. It was made possible by [@qwopqwop200](https://github.com/qwopqwop200/GPTQ-for-LLaMa): https://github.com/qwopqwop200/GPTQ-for-LLaMa GPTQ-for-LLaMa is the original adaptation of GPTQ for the LLaMA model. It was made possible by [@qwopqwop200](https://github.com/qwopqwop200/GPTQ-for-LLaMa): https://github.com/qwopqwop200/GPTQ-for-LLaMa
Different branches of GPTQ-for-LLaMa are currently available, including: A Python package containing both major CUDA versions of GPTQ-for-LLaMa is used to simplify installation and compatibility: https://github.com/jllllll/GPTQ-for-LLaMa-CUDA
| Branch | Comment |
|----|----|
| [Old CUDA branch (recommended)](https://github.com/oobabooga/GPTQ-for-LLaMa/) | The fastest branch, works on Windows and Linux. |
| [Up-to-date triton branch](https://github.com/qwopqwop200/GPTQ-for-LLaMa) | Slightly more precise than the old CUDA branch from 13b upwards, significantly more precise for 7b. 2x slower for small context size and only works on Linux. |
| [Up-to-date CUDA branch](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda) | As precise as the up-to-date triton branch, 10x slower than the old cuda branch for small context size. |
Overall, I recommend using the old CUDA branch. It is included by default in the one-click-installer for this web UI.
### Installation
Start by cloning GPTQ-for-LLaMa into your `text-generation-webui/repositories` folder:
```
mkdir repositories
cd repositories
git clone https://github.com/oobabooga/GPTQ-for-LLaMa.git -b cuda
```
If you want to use the up-to-date CUDA or triton branches instead of the old CUDA branch, use these commands:
```
git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git -b cuda
```
```
git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git -b triton
```
Next you need to install the CUDA extensions. You can do that either by installing the precompiled wheels, or by compiling the wheels yourself.
### Precompiled wheels ### Precompiled wheels
Kindly provided by our friend jllllll: https://github.com/jllllll/GPTQ-for-LLaMa-Wheels Kindly provided by our friend jllllll: https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases
Windows: Wheels are included in requirements.txt and are installed with the webui on supported systems.
```
pip install https://github.com/jllllll/GPTQ-for-LLaMa-Wheels/raw/main/quant_cuda-0.0.0-cp310-cp310-win_amd64.whl
```
Linux:
```
pip install https://github.com/jllllll/GPTQ-for-LLaMa-Wheels/raw/Linux-x64/quant_cuda-0.0.0-cp310-cp310-linux_x86_64.whl
```
### Manual installation ### Manual installation
@ -124,30 +84,42 @@ pip install https://github.com/jllllll/GPTQ-for-LLaMa-Wheels/raw/Linux-x64/quant
``` ```
conda activate textgen conda activate textgen
conda install -c conda-forge cudatoolkit-dev conda install cuda -c nvidia/label/cuda-11.7.1
``` ```
The command above takes some 10 minutes to run and shows no progress bar or updates along the way. The command above takes some 10 minutes to run and shows no progress bar or updates along the way.
You are also going to need to have a C++ compiler installed. On Linux, `sudo apt install build-essential` or equivalent is enough. You are also going to need to have a C++ compiler installed. On Linux, `sudo apt install build-essential` or equivalent is enough. On Windows, Visual Studio or Visual Studio Build Tools is required.
If you're using an older version of CUDA toolkit (e.g. 11.7) but the latest version of `gcc` and `g++` (12.0+), you should downgrade with: `conda install -c conda-forge gxx==11.3.0`. Kernel compilation will fail otherwise. If you're using an older version of CUDA toolkit (e.g. 11.7) but the latest version of `gcc` and `g++` (12.0+) on Linux, you should downgrade with: `conda install -c conda-forge gxx==11.3.0`. Kernel compilation will fail otherwise.
#### Step 2: compile the CUDA extensions #### Step 2: compile the CUDA extensions
``` ```
cd repositories/GPTQ-for-LLaMa python -m pip install git+https://github.com/jllllll/GPTQ-for-LLaMa-CUDA -v
python setup_cuda.py install
``` ```
### Getting pre-converted LLaMA weights ### Getting pre-converted LLaMA weights
These are models that you can simply download and place in your `models` folder. * Direct download (recommended):
* Converted without `group-size` (better for the 7b model): https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483891617 https://huggingface.co/Neko-Institute-of-Science/LLaMA-7B-4bit-128g
* Converted with `group-size` (better from 13b upwards): https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483941105
⚠️ The tokenizer files in the sources above may be outdated. Make sure to obtain the universal LLaMA tokenizer as described [here](https://github.com/oobabooga/text-generation-webui/blob/main/docs/LLaMA-model.md#option-1-pre-converted-weights). https://huggingface.co/Neko-Institute-of-Science/LLaMA-13B-4bit-128g
https://huggingface.co/Neko-Institute-of-Science/LLaMA-30B-4bit-128g
https://huggingface.co/Neko-Institute-of-Science/LLaMA-65B-4bit-128g
These models were converted with `desc_act=True`. They work just fine with ExLlama. For AutoGPTQ, they will only work on Linux with the `triton` option checked.
* Torrent:
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483891617
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483941105
These models were converted with `desc_act=False`. As such, they are less accurate, but they work with AutoGPTQ on Windows. The `128g` versions are better from 13b upwards, and worse for 7b. The tokenizer files in the torrents are outdated, in particular the files called `tokenizer_config.json` and `special_tokens_map.json`. Here you can find those files: https://huggingface.co/oobabooga/llama-tokenizer
### Starting the web UI: ### Starting the web UI:
@ -191,22 +163,17 @@ This requires using a monkey patch that is supported by this web UI: https://git
To use it: To use it:
1. Clone `johnsmith0031/alpaca_lora_4bit` into the repositories folder: 1. Install alpaca_lora_4bit using pip
``` ```
cd text-generation-webui/repositories git clone https://github.com/johnsmith0031/alpaca_lora_4bit.git
git clone https://github.com/johnsmith0031/alpaca_lora_4bit cd alpaca_lora_4bit
git fetch origin winglian-setup_pip
git checkout winglian-setup_pip
pip install .
``` ```
⚠️ I have tested it with the following commit specifically: `2f704b93c961bf202937b10aac9322b092afdce0` 2. Start the UI with the `--monkey-patch` flag:
2. Install https://github.com/sterlind/GPTQ-for-LLaMa with this command:
```
pip install git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
```
3. Start the UI with the `--monkey-patch` flag:
``` ```
python server.py --model llama-7b-4bit-128g --listen --lora tloen_alpaca-lora-7b --monkey-patch python server.py --model llama-7b-4bit-128g --listen --lora tloen_alpaca-lora-7b --monkey-patch

View file

@ -1,35 +0,0 @@
# Generation parameters
For a description of the generation parameters provided by the transformers library, see this link: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
### llama.cpp
llama.cpp only uses the following parameters:
* temperature
* top_p
* top_k
* repetition_penalty
* tfs
* mirostat_mode
* mirostat_tau
* mirostat_eta
### ExLlama
ExLlama only uses the following parameters:
* temperature
* top_p
* top_k
* repetition_penalty
* repetition_penalty_range
* typical_p
### RWKV
RWKV only uses the following parameters when loaded through the old .pth weights:
* temperature
* top_p
* top_k

View file

@ -9,10 +9,21 @@ This guide will cover usage through the official `transformers` implementation.
### Option 1: pre-converted weights ### Option 1: pre-converted weights
* Torrent: https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789 * Direct download (recommended):
* Direct download: https://huggingface.co/Neko-Institute-of-Science
⚠️ The tokenizers for the Torrent source above and also for many LLaMA fine-tunes available on Hugging Face may be outdated, in particular the files called `tokenizer_config.json` and `special_tokens_map.json`. Here you can find those files: https://huggingface.co/oobabooga/llama-tokenizer https://huggingface.co/Neko-Institute-of-Science/LLaMA-7B-HF
https://huggingface.co/Neko-Institute-of-Science/LLaMA-13B-HF
https://huggingface.co/Neko-Institute-of-Science/LLaMA-30B-HF
https://huggingface.co/Neko-Institute-of-Science/LLaMA-65B-HF
* Torrent:
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789
The tokenizer files in the torrent above are outdated, in particular the files called `tokenizer_config.json` and `special_tokens_map.json`. Here you can find those files: https://huggingface.co/oobabooga/llama-tokenizer
### Option 2: convert the weights yourself ### Option 2: convert the weights yourself

35
docs/LLaMA-v2-model.md Normal file
View file

@ -0,0 +1,35 @@
# LLaMA-v2
To convert LLaMA-v2 from the `.pth` format provided by Meta to transformers format, follow the steps below:
1) `cd` into your `llama` folder (the one containing `download.sh` and the models that you downloaded):
```
cd llama
```
2) Clone the transformers library:
```
git clone 'https://github.com/huggingface/transformers'
```
3) Create symbolic links from the downloaded folders to names that the conversion script can recognize:
```
ln -s llama-2-7b 7B
ln -s llama-2-13b 13B
```
4) Do the conversions:
```
mkdir llama-2-7b-hf llama-2-13b-hf
python ./transformers/src/transformers/models/llama/convert_llama_weights_to_hf.py --input_dir . --model_size 7B --output_dir llama-2-7b-hf --safe_serialization true
python ./transformers/src/transformers/models/llama/convert_llama_weights_to_hf.py --input_dir . --model_size 13B --output_dir llama-2-13b-hf --safe_serialization true
```
5) Move the output folders inside `text-generation-webui/models`
6) Have fun
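
If you want to sanity-check a converted folder before moving it into `models/`, a quick load with `transformers` is enough. This is a sketch that assumes `transformers` and `torch` are installed and that `llama-2-7b-hf` is the output folder from step 4:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "llama-2-7b-hf"  # output folder from the conversion step above
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path)  # loads on CPU; add device_map="auto" if accelerate is installed

inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```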

View file

@ -8,11 +8,9 @@
* [Docker](Docker.md) * [Docker](Docker.md)
* [ExLlama](ExLlama.md) * [ExLlama](ExLlama.md)
* [Extensions](Extensions.md) * [Extensions](Extensions.md)
* [FlexGen](FlexGen.md)
* [Generation parameters](Generation-parameters.md)
* [GPTQ models (4 bit mode)](GPTQ-models-(4-bit-mode).md) * [GPTQ models (4 bit mode)](GPTQ-models-(4-bit-mode).md)
* [llama.cpp models](llama.cpp-models.md)
* [LLaMA model](LLaMA-model.md) * [LLaMA model](LLaMA-model.md)
* [llama.cpp](llama.cpp.md)
* [LoRA](LoRA.md) * [LoRA](LoRA.md)
* [Low VRAM guide](Low-VRAM-guide.md) * [Low VRAM guide](Low-VRAM-guide.md)
* [RWKV model](RWKV-model.md) * [RWKV model](RWKV-model.md)

View file

@ -1,53 +0,0 @@
# Using llama.cpp in the web UI
## Setting up the models
#### Pre-converted
Place the model in the `models` folder, making sure that its name contains `ggml` somewhere and ends in `.bin`.
#### Convert LLaMA yourself
Follow the instructions in the llama.cpp README to generate the `ggml-model.bin` file: https://github.com/ggerganov/llama.cpp#usage
## GPU acceleration
Enabled with the `--n-gpu-layers` parameter.
* If you have enough VRAM, use a high number like `--n-gpu-layers 200000` to offload all layers to the GPU.
* Otherwise, start with a low number like `--n-gpu-layers 10` and then gradually increase it until you run out of memory.
To use this feature, you need to manually compile and install `llama-cpp-python` with GPU support.
#### Linux
```
pip uninstall -y llama-cpp-python
CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
```
#### Windows
```
pip uninstall -y llama-cpp-python
set CMAKE_ARGS="-DLLAMA_CUBLAS=on"
set FORCE_CMAKE=1
pip install llama-cpp-python --no-cache-dir
```
#### macOS
```
pip uninstall -y llama-cpp-python
CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
```
Here you can find the different compilation options for OpenBLAS / cuBLAS / CLBlast: https://pypi.org/project/llama-cpp-python/
## Performance
This was the performance of llama-7b int4 on my i5-12400F (cpu only):
> Output generated in 33.07 seconds (6.05 tokens/s, 200 tokens, context 17)
You can change the number of threads with `--threads N`.

43
docs/llama.cpp.md Normal file
View file

@ -0,0 +1,43 @@
# llama.cpp
llama.cpp is the best backend in two important scenarios:
1) You don't have a GPU.
2) You want to run a model that doesn't fit into your GPU.
## Setting up the models
#### Pre-converted
Download the GGUF models directly into your `text-generation-webui/models` folder. It will be a single file.
* Make sure its name ends in `.gguf`.
* `q4_K_M` quantization is recommended.
#### Convert Llama yourself
Follow the instructions in the llama.cpp README to generate a GGUF: https://github.com/ggerganov/llama.cpp#prepare-data--run
## GPU acceleration
Enabled with the `--n-gpu-layers` parameter.
* If you have enough VRAM, use a high number like `--n-gpu-layers 1000` to offload all layers to the GPU.
* Otherwise, start with a low number like `--n-gpu-layers 10` and then gradually increase it until you run out of memory.
This feature works out of the box for NVIDIA GPUs on Linux (amd64) or Windows. For other GPUs, you need to uninstall `llama-cpp-python` with
```
pip uninstall -y llama-cpp-python
```
and then recompile it using the commands here: https://pypi.org/project/llama-cpp-python/
#### macOS
For macOS, these are the commands:
```
pip uninstall -y llama-cpp-python
CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
```
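
After reinstalling, you can check that offloading works by loading a GGUF file directly with `llama-cpp-python`, outside the webui. A minimal sketch, assuming the package is installed and the model path below is replaced with an actual file from your `models` folder:

```python
from llama_cpp import Llama

# Example path only; any GGUF file in text-generation-webui/models works
llm = Llama(model_path="models/llama-2-7b.Q4_K_M.gguf", n_gpu_layers=1000)

out = llm("Q: Name the planets in the solar system. A:", max_tokens=32)
print(out["choices"][0]["text"])
```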

View file

@ -22,19 +22,31 @@ from requests.adapters import HTTPAdapter
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
base = "https://huggingface.co"
class ModelDownloader: class ModelDownloader:
def __init__(self, max_retries=5): def __init__(self, max_retries=5):
self.s = requests.Session() self.session = requests.Session()
if max_retries: if max_retries:
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
if os.getenv('HF_TOKEN') is not None:
self.session.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'}
def sanitize_model_and_branch_names(self, model, branch): def sanitize_model_and_branch_names(self, model, branch):
if model[-1] == '/': if model[-1] == '/':
model = model[:-1] model = model[:-1]
if model.startswith(base + '/'):
model = model[len(base) + 1:]
model_parts = model.split(":")
model = model_parts[0] if len(model_parts) > 0 else model
branch = model_parts[1] if len(model_parts) > 1 else branch
if branch is None: if branch is None:
branch = "main" branch = "main"
else: else:
@ -45,8 +57,7 @@ class ModelDownloader:
return model, branch return model, branch
def get_download_links_from_huggingface(self, model, branch, text_only=False): def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}" page = f"/api/models/{model}/tree/{branch}"
cursor = b"" cursor = b""
@ -55,12 +66,12 @@ class ModelDownloader:
classifications = [] classifications = []
has_pytorch = False has_pytorch = False
has_pt = False has_pt = False
# has_ggml = False has_gguf = False
has_safetensors = False has_safetensors = False
is_lora = False is_lora = False
while True: while True:
url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "") url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
r = self.s.get(url, timeout=20) r = self.session.get(url, timeout=10)
r.raise_for_status() r.raise_for_status()
content = r.content content = r.content
@ -70,16 +81,19 @@ class ModelDownloader:
for i in range(len(dict)): for i in range(len(dict)):
fname = dict[i]['path'] fname = dict[i]['path']
if specific_file not in [None, ''] and fname != specific_file:
continue
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')): if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
is_lora = True is_lora = True
is_pytorch = re.match("(pytorch|adapter|gptq)_model.*\.bin", fname) is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname)
is_safetensors = re.match(".*\.safetensors", fname) is_safetensors = re.match(r".*\.safetensors", fname)
is_pt = re.match(".*\.pt", fname) is_pt = re.match(r".*\.pt", fname)
is_ggml = re.match(".*ggml.*\.bin", fname) is_gguf = re.match(r'.*\.gguf', fname)
is_tokenizer = re.match("(tokenizer|ice).*\.model", fname) is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)): if any((is_pytorch, is_safetensors, is_pt, is_gguf, is_tokenizer, is_text)):
if 'lfs' in dict[i]: if 'lfs' in dict[i]:
sha256.append([fname, dict[i]['lfs']['oid']]) sha256.append([fname, dict[i]['lfs']['oid']])
@ -99,9 +113,9 @@ class ModelDownloader:
elif is_pt: elif is_pt:
has_pt = True has_pt = True
classifications.append('pt') classifications.append('pt')
elif is_ggml: elif is_gguf:
# has_ggml = True has_gguf = True
classifications.append('ggml') classifications.append('gguf')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor) cursor = base64.b64encode(cursor)
@ -113,12 +127,17 @@ class ModelDownloader:
if classifications[i] in ['pytorch', 'pt']: if classifications[i] in ['pytorch', 'pt']:
links.pop(i) links.pop(i)
return links, sha256, is_lora is_llamacpp = has_gguf and specific_file is not None
return links, sha256, is_lora, is_llamacpp
def get_output_folder(self, model, branch, is_lora, base_folder=None): def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, base_folder=None):
if base_folder is None: if base_folder is None:
base_folder = 'models' if not is_lora else 'loras' base_folder = 'models' if not is_lora else 'loras'
# If the model is of type GGUF, save directly in the base_folder
if is_llamacpp:
return Path(base_folder)
output_folder = f"{'_'.join(model.split('/')[-2:])}" output_folder = f"{'_'.join(model.split('/')[-2:])}"
if branch != 'main': if branch != 'main':
output_folder += f'_{branch}' output_folder += f'_{branch}'
@ -134,7 +153,7 @@ class ModelDownloader:
if output_path.exists() and not start_from_scratch: if output_path.exists() and not start_from_scratch:
# Check if the file has already been downloaded completely # Check if the file has already been downloaded completely
r = self.s.get(url, stream=True, timeout=20) r = self.session.get(url, stream=True, timeout=10)
total_size = int(r.headers.get('content-length', 0)) total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size: if output_path.stat().st_size >= total_size:
return return
@ -143,7 +162,7 @@ class ModelDownloader:
headers = {'Range': f'bytes={output_path.stat().st_size}-'} headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab' mode = 'ab'
with self.s.get(url, stream=True, headers=headers, timeout=20) as r: with self.session.get(url, stream=True, headers=headers, timeout=10) as r:
r.raise_for_status() # Do not continue the download if the request was unsuccessful r.raise_for_status() # Do not continue the download if the request was unsuccessful
total_size = int(r.headers.get('content-length', 0)) total_size = int(r.headers.get('content-length', 0))
block_size = 1024 * 1024 # 1MB block_size = 1024 * 1024 # 1MB
@ -155,16 +174,18 @@ class ModelDownloader:
f.write(data) f.write(data)
if total_size != 0 and self.progress_bar is not None: if total_size != 0 and self.progress_bar is not None:
count += len(data) count += len(data)
self.progress_bar(float(count) / float(total_size), f"Downloading {filename}") self.progress_bar(float(count) / float(total_size), f"{filename}")
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1): def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1): def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1, specific_file=None, is_llamacpp=False):
self.progress_bar = progress_bar self.progress_bar = progress_bar
# Creating the folder and writing the metadata # Create the folder and writing the metadata
output_folder.mkdir(parents=True, exist_ok=True) output_folder.mkdir(parents=True, exist_ok=True)
if not is_llamacpp:
metadata = f'url: https://huggingface.co/{model}\n' \ metadata = f'url: https://huggingface.co/{model}\n' \
f'branch: {branch}\n' \ f'branch: {branch}\n' \
f'download date: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n' f'download date: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n'
@ -176,8 +197,11 @@ class ModelDownloader:
metadata += '\n' metadata += '\n'
(output_folder / 'huggingface-metadata.txt').write_text(metadata) (output_folder / 'huggingface-metadata.txt').write_text(metadata)
# Downloading the files if specific_file:
print(f"Downloading {specific_file} to {output_folder}")
else:
print(f"Downloading the model to {output_folder}") print(f"Downloading the model to {output_folder}")
self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads) self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
def check_model_files(self, model, branch, links, sha256, output_folder): def check_model_files(self, model, branch, links, sha256, output_folder):
@ -213,6 +237,7 @@ if __name__ == '__main__':
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
@ -221,28 +246,29 @@ if __name__ == '__main__':
branch = args.branch branch = args.branch
model = args.MODEL model = args.MODEL
specific_file = args.specific_file
if model is None: if model is None:
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').") print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
sys.exit() sys.exit()
downloader = ModelDownloader(max_retries=args.max_retries) downloader = ModelDownloader(max_retries=args.max_retries)
# Cleaning up the model/branch names # Clean up the model/branch names
try: try:
model, branch = downloader.sanitize_model_and_branch_names(model, branch) model, branch = downloader.sanitize_model_and_branch_names(model, branch)
except ValueError as err_branch: except ValueError as err_branch:
print(f"Error: {err_branch}") print(f"Error: {err_branch}")
sys.exit() sys.exit()
# Getting the download links from Hugging Face # Get the download links from Hugging Face
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only) links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file)
# Getting the output folder # Get the output folder
output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=args.output) output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=args.output)
if args.check: if args.check:
# Check previously downloaded files # Check previously downloaded files
downloader.check_model_files(model, branch, links, sha256, output_folder) downloader.check_model_files(model, branch, links, sha256, output_folder)
else: else:
# Download files # Download files
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=args.threads) downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp)
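A rough sketch of driving the new `specific_file` path programmatically, mirroring the `__main__` flow above. The repository and file names are placeholder examples, and `download_model` is assumed to be an importable copy of `download-model.py`:

```python
# Sketch only: mirrors the __main__ flow of download-model.py shown above.
# "TheBloke/Llama-2-7B-GGUF" and the .gguf file name are placeholder examples,
# and download_model is assumed to be an importable copy of download-model.py.
from download_model import ModelDownloader

downloader = ModelDownloader(max_retries=5)
model, branch = downloader.sanitize_model_and_branch_names("TheBloke/Llama-2-7B-GGUF", "main")

specific_file = "llama-2-7b.Q4_K_M.gguf"
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(
    model, branch, text_only=False, specific_file=specific_file)

# When a single GGUF is requested, is_llamacpp is True and the file is saved
# directly into the models/ folder instead of a per-model subfolder.
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp)
downloader.download_model_files(model, branch, links, sha256, output_folder,
                                specific_file=specific_file, threads=1, is_llamacpp=is_llamacpp)
```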


@ -0,0 +1,96 @@
from functools import partial
import torch
import transformers
import math
from torch.optim.lr_scheduler import LambdaLR
# FPHAM custom training scheduler block - should be extracted to a separate file
last_print_label = ''
def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
global last_print_label
print_label = ''
num_warmup_steps = min(num_warmup_steps,num_firstepoch_steps)
if current_step < num_warmup_steps:
print_label = 'Scheduler: Warmup'
elif current_step < num_firstepoch_steps:
print_label = 'Scheduler: Hold'
else:
print_label = 'Scheduler: Annealing'
if print_label != last_print_label:
print(print_label)
last_print_label = print_label
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
if current_step < num_firstepoch_steps:
return 1.0
progress = float(current_step - num_firstepoch_steps) / float(max(1, num_training_steps - num_firstepoch_steps))
num_cycles = 0.5
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
def custom_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
"""
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
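num_firstepoch_steps (`int`):
The number of optimizer steps in the first epoch; warmup is clamped to this value, the learning rate is then held constant until this step, and cosine annealing begins afterwards.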
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_fp_cosine_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_firstepoch_steps = num_firstepoch_steps,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
class FPSchedulerTrainer(transformers.Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
#Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument.
if self.args.lr_scheduler_type == 'cosine':
num_train_epochs = self.args.num_train_epochs
num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
num_firstepoch_steps = math.ceil(num_training_steps/num_train_epochs)
num_warmup_acc = num_warmup_steps*self.args.gradient_accumulation_steps
num_firstepoch_steps_acc = num_firstepoch_steps*self.args.gradient_accumulation_steps
num_training_steps_acc = num_training_steps*self.args.gradient_accumulation_steps
num_warmup_acc_min = min(num_warmup_acc, num_firstepoch_steps_acc)
if num_warmup_acc>num_firstepoch_steps_acc:
print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to 1 epoch, essentially going from warmup to annealing.\033[0;37;0m")
print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
else:
print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
self.lr_scheduler = custom_scheduler_with_warmup(
optimizer=self.optimizer if optimizer is None else optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_firstepoch_steps = num_firstepoch_steps,
)
self._created_lr_scheduler = True
return self.lr_scheduler
else:
return super().create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
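To see the warmup → hold → annealing shape this scheduler produces, a throwaway optimizer can be stepped through it. A small sketch (the import path is an assumption; adjust it to wherever this file lives inside the extension):

```python
# Sketch: step a throwaway optimizer through the schedule to see the
# warmup -> hold (first epoch) -> cosine annealing shape. The import path is
# an assumption; adjust it to wherever this file lives inside the extension.
import torch

from custom_scheduler import custom_scheduler_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=3e-4)

scheduler = custom_scheduler_with_warmup(
    optimizer,
    num_warmup_steps=10,       # clamped to the first epoch inside the lambda
    num_training_steps=100,
    num_firstepoch_steps=25,   # e.g. 100 total steps over 4 epochs
)

for step in range(100):
    optimizer.step()
    scheduler.step()
    if step in (0, 9, 10, 24, 25, 60, 99):
        print(step, round(scheduler.get_last_lr()[0], 6))
```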


@ -0,0 +1,62 @@
import os
import json
def create_graph(lora_path, lora_name):
try:
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
peft_model_path = f'{lora_path}/training_graph.json'
image_model_path = f'{lora_path}/training_graph.png'
# Check if the JSON file exists
if os.path.exists(peft_model_path):
# Load data from JSON file
with open(peft_model_path, 'r') as file:
data = json.load(file)
# Extract x, y1, and y2 values
x = [item['epoch'] for item in data]
y1 = [item['learning_rate'] for item in data]
y2 = [item['loss'] for item in data]
# Create the line chart
fig, ax1 = plt.subplots(figsize=(10, 6))
# Plot y1 (learning rate) on the first y-axis
ax1.plot(x, y1, 'b-', label='Learning Rate')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Learning Rate', color='b')
ax1.tick_params('y', colors='b')
# Create a second y-axis
ax2 = ax1.twinx()
# Plot y2 (loss) on the second y-axis
ax2.plot(x, y2, 'r-', label='Loss')
ax2.set_ylabel('Loss', color='r')
ax2.tick_params('y', colors='r')
# Set the y-axis formatter to display numbers in scientific notation
ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
ax1.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
# Add grid
ax1.grid(True)
# Combine the legends for both plots
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='best')
# Set the title
plt.title(f'{lora_name} LR and Loss vs Epoch')
# Save the chart as an image
plt.savefig(image_model_path)
print(f"Graph saved in {image_model_path}")
else:
print(f"File 'training_graph.json' does not exist in the {lora_path}")
except ImportError:
print("matplotlib is not installed. Please install matplotlib to create PNG graphs")


@ -0,0 +1,27 @@
This is an expanded Training tab
- Chunking: the precise raw text slicer (PRTS) uses sentence slicing and makes sure chunks are clean on both ends
- Overlap chunking: creates additional overlapping blocks based on logical rules (i.e. no overlap block across a hard cut)
- Custom scheduler (follow the code to make your own): in LR Scheduler, select FP_low_epoch_annealing - it keeps the LR constant for the first epoch, then uses cosine annealing for the rest (this part would be best spun off into its own .py file)
- Save loss threshold: the "Save every n steps" checkpoints are not written until this loss threshold is reached (I definitely don't need multiple checkpoints at 2.5 loss - I'm usually interested in checkpoints between, say, 1.5 and 1.9 loss)
- Saves a graph PNG at the end with learning rate and loss per epoch
- Adds EOS to each block, or to hard cuts only
- Automatically lowers gradient accumulation if you go overboard and set it higher than the amount of training data allows - transformers would then throw an error (or used to, not sure if that's still true), but in any case this fixes the bad configuration
- Turn BOS on and off
- Target selector
### Notes
This uses its own chunking code for raw text, based on sentence splitting. This avoids weird cuts in the chunks: each chunk should now start with a sentence and end with a sentence. It works hand in hand with Hard Cut.
A proper use is to structure your text into logical blocks (ideas) separated by three \n, then use three \n as the hard cut string, as illustrated in the sketch below.
This way each chunk contains only one flow of ideas and does not derail into unrelated thoughts.
The overlapping code also creates overlapped blocks on a sentence basis, but never across a hard cut, and thus never across different ideas either.
Does it make any sense? No? Hmmmm...
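A tiny illustration of the hard-cut idea (not the extension's actual slicer, just the structure it expects):

```python
# Illustration only: the hard cut string marks boundaries that neither regular
# chunks nor overlap blocks are allowed to cross. In the extension, the UI
# value '\n\n\n' is unescaped the same way before splitting.
hard_cut_string = "\\n\\n\\n"            # as typed into the Hard Cut String field
cut = hard_cut_string.replace("\\n", "\n")

raw_text = (
    "First idea. It spans a few sentences. They stay together."
    + cut +
    "Second, unrelated idea. No overlap block will bridge back to the first."
    + cut +
    "Third idea."
)

blocks = [b.strip() for b in raw_text.split(cut) if b.strip()]
print(len(blocks))  # 3 logical blocks; sentence-level slicing then runs inside each
```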
### Targets
A normal LoRA targets q and v, and that's what you should use.
You can use (q k v o) or (q k v), which gives you a lot more trainable parameters. The benefit is that you can keep the rank lower and still attain the same coherency as q-v with a high rank. Guanaco, for example, was trained with QLoRA on q k v o, and its authors swear by it.
I also added k-v-down, which is lifted from IA3 and is a very odd choice for a LoRA, but it produced adorable stylistic craziness when training on raw structured text, bringing the loss all the way down to 1.1. It didn't overfit (q-v would just be writing entire novels at loss 1.1), and it still followed the instructions carried over from the previous fine-tuning. YMMV, of course.
Using All trains all 7 targets (q, k, v, o, up, down, gate) - I'm not sure there is much benefit over attention-only q-k-v-o, and it sure makes the LoRA huge, if that's what you like.
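For reference, these selector choices map onto LLaMA projection-module lists that get handed to PEFT's `LoraConfig`, roughly like this (a sketch mirroring the mapping in this extension's `script.py`):

```python
# Mirrors the target_modules mapping this extension's script.py applies for
# LLaMA-architecture models before building the PEFT LoraConfig.
from peft import LoraConfig

TRAINING_PROJECTIONS = {
    "q-v":      ["q_proj", "v_proj"],                    # classic LoRA default
    "q-k-v":    ["q_proj", "k_proj", "v_proj"],
    "q-k-v-o":  ["q_proj", "k_proj", "v_proj", "o_proj"],
    "k-v-down": ["k_proj", "v_proj", "down_proj"],       # IA3-inspired
    "all":      ["gate_proj", "down_proj", "up_proj",
                 "q_proj", "k_proj", "v_proj", "o_proj"],
}

config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=TRAINING_PROJECTIONS["q-k-v-o"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
```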


@ -0,0 +1,810 @@
import os
os.environ["WANDB_MODE"] = "offline"
# os.environ["WANDB_DISABLED"] = "true"
import json
import math
import random
import shutil
import sys
import threading
import time
import traceback
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
import transformers
from .custom_scheduler import FPSchedulerTrainer
from .matplotgraph import create_graph
from .train_utils import get_available_loras_local, precise_cut
from datasets import Dataset, load_dataset
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
set_peft_model_state_dict
)
from peft.utils.other import \
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
from modules import shared, utils
from modules.ui import create_refresh_button
from modules.evaluate import (
calculate_perplexity,
generate_markdown_table,
save_past_evaluations
)
from modules.logging_colors import logger
from modules.models import reload_model
from modules.utils import natural_keys
params = {
"display_name": "Training PRO",
"is_tab": True
}
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to", "precize_slicing_overlap", "add_eos_token_type", "save_steps_under_loss", "add_bos_token", "training_projection"]
WANT_INTERRUPT = False
train_log = {}
train_template = {}
train_log_graph = []
Lora_sortedByTime = False
train_choices = ["all","q-k-v-o","q-k-v","k-v-down","q-v"]
def ui():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
tmp = gr.State('')
with gr.Row():
with gr.Column():
gr.Markdown("This is enhanced version of Lora Training with a sentence based RAW text chunking code")
with gr.Row():
with gr.Column(scale=5):
with gr.Row():
copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=get_available_loras_local(Lora_sortedByTime), elem_classes=['slim-dropdown'])
create_refresh_button(copy_from, lambda: None, lambda: {'choices': get_available_loras_local(Lora_sortedByTime)}, 'refresh-button')
with gr.Column():
sort_byTime = gr.Checkbox(label='Sort list by Date', value=False, info='Sorts Loras by date created.', elem_classes=['no-background'])
with gr.Row():
with gr.Column(scale=5):
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
with gr.Column():
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
with gr.Row():
with gr.Column():
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
with gr.Column():
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
save_steps_under_loss = gr.Slider(label='Save Loss Threshold', value=1.9, minimum=0.0, maximum=3.0, step=0.1, info='Save checkpoints only if the loss is less than or equal to this threshold. (0 = save all)')
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt', 'FP_low_epoch_annealing'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
with gr.Accordion(label='Advanced Options', open=True):
with gr.Row():
with gr.Column():
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
training_projection = gr.Radio(value = train_choices[4], label='LLaMA Target Projections', info='Change the targets (LORA is typically q-v)', choices=train_choices)
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown'])
with gr.Column():
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
add_bos_token = gr.Checkbox(label='Add BOS token', value=True, info="Adds BOS token for each dataset item")
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item")
add_eos_token_type = gr.Dropdown(label='EOS placement (raw text)', choices=['Every Block', 'Hard Cut Blocks Only'], value='Every Block', info='', allow_custom_value = False)
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
with gr.Column():
with gr.Tab(label='Formatted Dataset'):
with gr.Row():
format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'])
create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button')
with gr.Row():
dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'])
create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
with gr.Row():
eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'])
create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Tab(label="Raw text file"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.', elem_classes=['slim-dropdown'])
create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
with gr.Row():
with gr.Column():
precize_slicing_overlap = gr.Checkbox(label='Create Overlapping blocks', value = True)
with gr.Column():
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a cut between logical blocks of text (ex. Ideas or Chapters). Helps prevent unwanted overlap between unrelated ideas.')
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Text blocks that have less or equal characters than this number.')
with gr.Row():
start_button = gr.Button("Start LoRA Training", variant='primary')
stop_button = gr.Button("Interrupt")
output = gr.Markdown(value="Ready")
with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
with gr.Row():
with gr.Column():
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True)
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
with gr.Row():
with gr.Column():
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
with gr.Column():
max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
with gr.Row():
start_current_evaluation = gr.Button("Evaluate loaded model")
start_evaluation = gr.Button("Evaluate selected models")
stop_evaluation = gr.Button("Interrupt")
with gr.Column():
evaluation_log = gr.Markdown(value='')
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
with gr.Row():
save_comments = gr.Button('Save comments', elem_classes="small-button")
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
# Training events
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to, precize_slicing_overlap, add_eos_token_type, save_steps_under_loss, add_bos_token, training_projection]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
start_button.click(do_train, all_params, output)
stop_button.click(do_interrupt, None, None, queue=False)
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
# Evaluation events. For some reason, the interrupt event
# doesn't work with the .then() syntax, so I write them one
# by one in this ugly but functional way.
ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
start_current_evaluation.click(lambda: ['current model'], None, tmp)
ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
refresh_table.click(generate_markdown_table, None, evaluation_table, show_progress=True)
save_comments.click(
save_past_evaluations, evaluation_table, None).then(
lambda: "Comments saved.", None, evaluation_log, show_progress=False)
def reload_lora():
global Lora_sortedByTime
return gr.Dropdown.update(choices=get_available_loras_local(Lora_sortedByTime))
def global_lora_time(sort_byTime):
global Lora_sortedByTime
Lora_sortedByTime = sort_byTime
sort_byTime.change(global_lora_time, sort_byTime, None).then(reload_lora,None,copy_from)
def do_interrupt():
global WANT_INTERRUPT
WANT_INTERRUPT = True
def do_copy_params(lora_name: str, *args):
f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json"
if Path(f_name).is_file():
with open(f_name, 'r', encoding='utf-8') as format_file:
params: dict[str, str] = json.load(format_file)
else:
params = {}
result = list()
for i in range(0, len(PARAMETERS)):
key = PARAMETERS[i]
if key in params:
result.append(params[key])
else:
result.append(args[i])
return result
def change_rank_limit(use_higher_ranks: bool):
mult = 2 if use_higher_ranks else 1
return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
def clean_path(base_path: str, path: str):
"""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
path = path.replace('\\', '/').replace('..', '_')
if base_path is None:
return path
return f'{Path(base_path).absolute()}/{path}'
def backup_adapter(input_folder):
# Get the creation date of the file adapter_model.bin
try:
adapter_file = Path(f"{input_folder}/adapter_model.bin")
if adapter_file.is_file():
logger.info("Backing up existing LoRA adapter...")
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
# Create the new subfolder
subfolder_path = Path(f"{input_folder}/{creation_date_str}")
subfolder_path.mkdir(parents=True, exist_ok=True)
# Check if the file already exists in the subfolder
backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin")
if backup_adapter_file.is_file():
print(" - Backup already exists. Skipping backup process.")
return
# Copy existing files to the new subfolder
existing_files = Path(input_folder).iterdir()
for file in existing_files:
if file.is_file():
shutil.copy2(file, subfolder_path)
except Exception as e:
print("An error occurred in backup_adapter:", str(e))
def calc_trainable_parameters(model):
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
return trainable_params, all_param
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str, precize_slicing_overlap: bool, add_eos_token_type: str, save_steps_under_loss: float, add_bos_token: bool, training_projection: str):
if shared.args.monkey_patch:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
replace_peft_model_with_int4_lora_model
)
replace_peft_model_with_int4_lora_model()
global WANT_INTERRUPT
WANT_INTERRUPT = False
# == Input validation / processing ==
yield "Preparing the input..."
lora_file_path = clean_path(None, lora_name)
if lora_file_path.strip() == '':
yield "Missing or invalid LoRA file name input."
return
lora_file_path = f"{Path(shared.args.lora_dir)}/{lora_file_path}"
actual_lr = float(learning_rate)
model_type = type(shared.model).__name__
if model_type in MODEL_CLASSES:
model_id = MODEL_CLASSES[model_type]
else:
model_id = "llama"
if model_type == "PeftModelForCausalLM":
if len(shared.lora_names) > 0:
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
else:
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
else:
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
time.sleep(5)
if shared.args.loader == 'GPTQ-for-LLaMa' and not shared.args.monkey_patch:
yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`"
return
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes."
return
gradient_accumulation_steps = batch_size // micro_batch_size
shared.tokenizer.pad_token_id = 0
shared.tokenizer.padding_side = "left"
def encode(text, prepend_bos_token):
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
# Check if the first two tokens are BOS
if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
result = result[1:]
if not prepend_bos_token and result[0] == shared.tokenizer.bos_token_id:
result = result[1:]
return result
def tokenize(prompt, append_eos_token=False, prepend_bos_token = False):
if train_only_after == '' or train_only_after not in prompt:
input_ids = encode(prompt, prepend_bos_token)
if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
input_ids.append(shared.tokenizer.eos_token_id)
input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
labels = [1] * len(input_ids)
else:
ind = prompt.index(train_only_after) + len(train_only_after)
before_tokens = encode(prompt[:ind], prepend_bos_token)
after_tokens = encode(prompt[ind:], False)
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
after_tokens.append(shared.tokenizer.eos_token_id)
full_length = len(after_tokens) + len(before_tokens)
if full_length > cutoff_len:
after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
else:
before_tokens = [shared.tokenizer.pad_token_id] * (cutoff_len - full_length) + before_tokens
input_ids = before_tokens + after_tokens
labels = [-100] * len(before_tokens) + [1] * len(after_tokens)
input_ids = torch.tensor(input_ids)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
}
train_template.clear()
print(f"*** LoRA: {lora_name} ***")
# END OF FPHAM SENTENCE SPLIT functions ===================
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
train_template["template_type"] = "raw_text"
logger.info("Loading raw text file dataset...")
fullpath = clean_path('training/datasets', f'{raw_text_file}')
fullpath = Path(fullpath)
if fullpath.is_dir():
logger.info('Training path directory {}'.format(raw_text_file))
raw_text = ""
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
for file_path in file_paths:
if file_path.is_file():
with file_path.open('r', encoding='utf-8') as file:
raw_text += file.read().replace('\r', '')
logger.info(f"Loaded training file: {file_path.name}")
else:
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read().replace('\r', '')
# FPHAM PRECISE SLICING
if min_chars<0:
min_chars = 0
add_EOS_to_all = add_eos_token and add_eos_token_type == 'Every Block'
add_EOS_to_HC = add_eos_token and add_eos_token_type != 'Every Block'
#print (f"add_eos_token {add_eos_token}, add_EOS_to_all {add_EOS_to_all}, add_EOS_to_HC {add_EOS_to_HC}")
# == New more precise slicing on sentence boundary ==
text_chunks = precise_cut(raw_text, precize_slicing_overlap, min_chars, add_EOS_to_HC, cutoff_len, hard_cut_string)
train_data = Dataset.from_list([tokenize(x, add_EOS_to_all, add_bos_token) for x in text_chunks])
if add_EOS_to_all:
print(f"Added EOS to {len(text_chunks)} blocks")
del text_chunks
eval_data = None
else:
if dataset in ['None', '']:
yield "Missing dataset choice input, cannot continue."
return
if format in ['None', '']:
yield "Missing format choice input, cannot continue."
return
train_template["template_type"] = "dataset"
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
format_data: dict[str, str] = json.load(formatFile)
# == store training prompt ==
for _, value in format_data.items():
prompt_key = f"template_{len(train_template)}"
train_template[prompt_key] = value
def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
for key, val in data_point.items():
if type(val) is str:
data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
def generate_and_tokenize_prompt(data_point):
prompt = generate_prompt(data_point)
return tokenize(prompt, add_eos_token, add_bos_token)
logger.info("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
print(f"BOS: {add_bos_token} EOS: {add_eos_token}")
if eval_dataset == 'None':
eval_data = None
else:
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
# == We MUST reload model if it went through any previous training, even failed one ==
if shared.model_dirty_from_training:
selected_model = shared.model_name
if selected_model:
print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m")
try:
yield f"Reloading {selected_model}..."
reload_model()
if shared.model is not None:
print("Model reloaded OK, continue with training.")
else:
return f"Failed to load {selected_model}."
except:
exc = traceback.format_exc()
logger.error('Failed to reload the model.')
print(exc)
return exc.replace('\n', '\n\n')
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...")
prepare_model_for_kbit_training(shared.model)
# base model is now frozen and should not be reused for any other LoRA training than this one
shared.model_dirty_from_training = True
if training_projection==train_choices[0]:
model_to_lora_modules["llama"] = ["gate_proj","down_proj","up_proj","q_proj","k_proj","v_proj","o_proj"]
elif training_projection==train_choices[1]:
model_to_lora_modules["llama"] = ["q_proj","k_proj", "v_proj", "o_proj"]
elif training_projection==train_choices[2]:
model_to_lora_modules["llama"] = ["q_proj","k_proj", "v_proj"]
elif training_projection==train_choices[3]:
model_to_lora_modules["llama"] = ["k_proj", "v_proj", "down_proj"]
else:
model_to_lora_modules["llama"] = ["q_proj", "v_proj"]
logger.info("Preparing for training...")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=model_to_lora_modules[model_id],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
# == Backup the existing adapter ==
if not always_override:
backup_adapter(lora_file_path)
# == get model trainable params
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
try:
logger.info("Creating LoRA model...")
lora_model = get_peft_model(shared.model, config)
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
logger.info("Loading existing LoRA data...")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
set_peft_model_state_dict(lora_model, state_dict_peft)
except:
yield traceback.format_exc().replace('\n', '\n\n')
return
if shared.args.monkey_patch:
from alpaca_lora_4bit.autograd_4bit import Autograd4bitQuantLinear
from alpaca_lora_4bit.models import Linear4bitLt
for _, m in lora_model.named_modules():
if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
if m.is_v1_model:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
class Tracked():
def __init__(self):
self.current_steps = 0
self.max_steps = 0
self.did_save = False
tracked = Tracked()
actual_save_steps = math.ceil(save_steps / gradient_accumulation_steps)
class Callbacks(transformers.TrainerCallback):
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
tracked.current_steps = state.global_step * gradient_accumulation_steps
tracked.max_steps = state.max_steps * gradient_accumulation_steps
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
current_loss = float(train_log.get('loss', 0.0))
if current_loss <= save_steps_under_loss or save_steps_under_loss==0.0:
lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Checkpoint-{tracked.current_steps} saved")
# Save log
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2)
# == Save training prompt ==
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_prompt.json", 'w', encoding='utf-8') as file:
json.dump(train_template, file, indent=2)
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
tracked.current_steps += 1
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
train_log.update(logs)
train_log.update({"current_steps": tracked.current_steps})
if WANT_INTERRUPT:
print("\033[1;31;1mInterrupted by user\033[0;37;0m")
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m", end='')
entry = {
'current_steps': int(train_log.get('current_steps',0)),
'loss': float(train_log.get('loss', 0.0)),
'learning_rate': float(train_log.get('learning_rate', 0.0)),
'epoch': float(train_log.get('epoch', 0.0))
}
# Add the entry to the continuous log
train_log_graph.append(entry)
# Save the graph log for now, we can later generate full graph
with open(f"{lora_file_path}/training_graph.json", 'w') as file:
json.dump(train_log_graph, file, indent=4)
if 'loss' in logs:
loss = float(logs['loss'])
if loss <= stop_at_loss:
control.should_epoch_stop = True
control.should_training_stop = True
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
# FPHAM SAMPLE REQ Transformers error handling
sample_req = int(train_data.num_rows)//micro_batch_size
if sample_req < gradient_accumulation_steps:
print(f"\033[1;31;1mWARNING: Current gradient accumulation is too high for the amount of training data.\033[0;37;0m")
print(f"Gradient accumulation: {gradient_accumulation_steps} should be less than: {sample_req}. \033[1;31;1mThis could crash Accelerate/Transformers\033[0;37;0m")
min_batchSize = sample_req*micro_batch_size
print(f"Preferable fix: \033[1;31;1mIncrease the size of dataset\033[0;37;0m")
print(f"... or Decrerase Batch Size \033[1;31;1m{batch_size}\033[0;37;0m to below {min_batchSize}")
gradient_accumulation_steps = max(1,sample_req-1)
print(f"Last resort fix for this run: Lowering Gradient accumulation to {gradient_accumulation_steps}. [Good luck]")
else:
print(f"Data Size Check: Gradient accumulation: {gradient_accumulation_steps} <= Data/Batch {sample_req} ... [OK]")
#END OF FPHAM SAMPLE REQ
# FPHAM Custom Scheduler ==
custom_scheduller = False
lr_scheduler_type_arg = lr_scheduler_type
if lr_scheduler_type == 'FP_low_epoch_annealing':
custom_scheduller = True
lr_scheduler_type_arg = 'cosine'
args=transformers.TrainingArguments(
report_to=report_to if report_to != "None" else None,
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
num_train_epochs=epochs,
learning_rate=actual_lr,
fp16=False if shared.args.cpu else True,
optim=optimizer,
logging_steps=1,
evaluation_strategy="steps" if eval_data is not None else "no",
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
save_strategy="steps" if eval_data is not None else "no",
output_dir=lora_file_path,
lr_scheduler_type=lr_scheduler_type_arg,
load_best_model_at_end=eval_data is not None,
# TODO: Enable multi-device support
ddp_find_unused_parameters=None,
no_cuda=shared.args.cpu,
)
if custom_scheduller:
trainer = FPSchedulerTrainer(
model=lora_model,
train_dataset=train_data,
eval_dataset=eval_data,
args=args,
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=list([Callbacks()])
)
else:
trainer = transformers.Trainer(
model=lora_model,
train_dataset=train_data,
eval_dataset=eval_data,
args=args,
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=list([Callbacks()])
)
# END OF FPHAM CUSTOM SCHEDULER
lora_model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
lora_model = torch.compile(lora_model)
# == Save parameters for reuse ==
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
vars = locals()
json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
# == Save training prompt ==
with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
json.dump(train_template, file, indent=2)
# == Main run and monitor loop ==
logger.info("Starting training...")
yield "Starting..."
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
print(f"Training '{model_id}' model using ({projections_string}) projections")
if lora_all_param > 0:
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
train_log.update({"base_model_name": shared.model_name})
train_log.update({"base_model_class": shared.model.__class__.__name__})
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
train_log.update({"projections": projections_string})
if stop_at_loss > 0:
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
if WANT_INTERRUPT:
yield "Interrupted before start."
return
def log_train_dataset(trainer):
decoded_entries = []
# Try to decode the entries and write the log file
try:
# Iterate over the first 10 elements in the dataset (or fewer if there are less than 10)
for i in range(min(10, len(trainer.train_dataset))):
decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids'])
decoded_entries.append({"value": decoded_text})
# Write the log file
Path('logs').mkdir(exist_ok=True)
with open(Path('logs/train_dataset_sample.json'), 'w') as json_file:
json.dump(decoded_entries, json_file, indent=4)
logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.")
except Exception as e:
logger.error(f"Failed to create log file due to error: {e}")
def threaded_run():
log_train_dataset(trainer)
trainer.train()
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
logger.info("LoRA training run is completed and saved.")
# Save log
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2)
thread = threading.Thread(target=threaded_run)
thread.start()
last_step = 0
start_time = time.perf_counter()
while thread.is_alive():
time.sleep(0.5)
if WANT_INTERRUPT:
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
elif tracked.current_steps != last_step:
last_step = tracked.current_steps
time_elapsed = time.perf_counter() - start_time
if time_elapsed <= 0:
timer_info = ""
total_time_estimate = 999
else:
its = tracked.current_steps / time_elapsed
if its > 1:
timer_info = f"`{its:.2f}` it/s"
else:
timer_info = f"`{1.0/its:.2f}` s/it"
total_time_estimate = (1.0 / its) * (tracked.max_steps)
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
# Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save:
logger.info("Training complete, saving...")
lora_model.save_pretrained(lora_file_path)
if WANT_INTERRUPT:
logger.info("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`."
else:
logger.info("Training complete!")
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training."
create_graph(lora_file_path, lora_name)
def format_time(seconds: float):
if seconds < 120:
return f"`{seconds:.0f}` seconds"
minutes = seconds / 60
if minutes < 120:
return f"`{minutes:.0f}` minutes"
hours = minutes / 60
return f"`{hours:.0f}` hours"


@ -0,0 +1,192 @@
import os
from modules import shared, utils
from pathlib import Path
import json
def list_subfoldersByTime(directory):
if not directory.endswith('/'):
directory += '/'
subfolders = []
path = directory
name_list = os.listdir(path)
full_list = [os.path.join(path,i) for i in name_list]
time_sorted_list = sorted(full_list, key=os.path.getmtime,reverse=True)
for entry in time_sorted_list:
if os.path.isdir(entry):
entry_str = f"{entry}" # Convert entry to a string
full_path = entry_str
entry_str = entry_str.replace('\\','/')
entry_str = entry_str.replace(f"{directory}", "") # Remove directory part
subfolders.append(entry_str)
return subfolders
def get_available_loras_local(_sortedByTime):
model_dir = shared.args.lora_dir # Update with the appropriate directory path
subfolders = []
if _sortedByTime:
subfolders = list_subfoldersByTime(model_dir)
else:
subfolders = utils.get_available_loras()
return subfolders
# FPHAM SPLIT BY SENTENCE BLOCK ===============
def split_sentences(text: str, cutoff_len: int):
sentences = []
sentence = ''
delimiters = ['. ', '? ', '! ', '... ', '.\n', '?\n', '!\n','...\n','</s>','<//>']
abbreviations = ['Mr. ', 'Mrs. ', 'Dr. ', 'Ms. ', 'St. ', 'Prof. ', 'Jr. ', 'Ltd. ', 'Capt. ', 'Col. ', 'Gen. ', 'Ave. ', 'Blvd. ', 'Co. ', 'Corp. ', 'Dept. ', 'Est. ', 'Gov. ', 'Inc. ', 'Ph.D. ', 'Univ. ']
errors = 0
max_cut = cutoff_len-1
prev_char = ''
for char in text:
sentence += char
if (any(sentence.endswith(delimiter) for delimiter in delimiters) and
not (prev_char.isupper() and len(sentence) >= 3 and sentence[-3] != ' ') and
not any(sentence.endswith(abbreviation) for abbreviation in abbreviations)):
tokens = shared.tokenizer.encode(sentence)
if len(tokens) > max_cut:
tokens = tokens[:max_cut]
sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True)
errors = errors + 1
sentences.append({'text': sentence, 'size': len(tokens)})
sentence = ''
prev_char = char
if sentence:
tokens = shared.tokenizer.encode(sentence)
if len(tokens) > max_cut:
tokens = tokens[:max_cut]
sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True)
errors = errors + 1
sentences.append({'text': sentence, 'size': len(tokens)})
if errors > 0:
print(f"Trimmed sentences beyond Cutoff Length: {errors}")
return sentences
# The goal of the following code is to create blocks of text plus overlapping blocks that:
# - respect sentence boundaries
# - always use all of the text
# - end a data block at every hard cut defined by hard_cut_string or </s>
# - never overlap across a hard cut or across a </s> token
def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str):
debug_slicer = False
EOSX_str = '<//>' #hardcut placeholder
EOS_str = '</s>'
print("Precise raw text slicer: ON")
cut_string = hard_cut_string.replace('\\n', '\n')
text = text.replace(cut_string, EOSX_str)
sentences = split_sentences(text, cutoff_len)
print(f"Sentences: {len(sentences)}")
sentencelist = []
currentSentence = ''
totalLength = 0
max_cut = cutoff_len-1
half_cut = cutoff_len//2
halfcut_length = 0
edgeindex = []
half_index = 0
for index, item in enumerate(sentences):
if halfcut_length+ item['size'] < half_cut:
halfcut_length += item['size']
half_index = index
else:
edgeindex.append(half_index)
halfcut_length = -2 * max_cut
if totalLength + item['size'] < max_cut and not currentSentence.endswith(EOSX_str):
currentSentence += item['text']
totalLength += item['size']
else:
if len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
currentSentence = item['text']
totalLength = item['size']
halfcut_length = item['size']
if len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
unique_blocks = len(sentencelist)
print(f"Text Blocks: {unique_blocks}")
#overlap strategies:
# don't overlap across HARD CUT (EOSX)
if overlap:
for edge_idx in edgeindex:
currentSentence = ''
totalLength = 0
for item in sentences[edge_idx:]:
if totalLength + item['size'] < max_cut:
currentSentence += item['text']
totalLength += item['size']
else:
#if by chance EOSX is at the end then it's acceptable
if currentSentence.endswith(EOSX_str) and len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
# otherwise don't cross hard cut
elif EOSX_str not in currentSentence and len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
currentSentence = ''
totalLength = 0
break
print(f"+ Overlapping blocks: {len(sentencelist)-unique_blocks}")
num_EOS = 0
for i in range(len(sentencelist)):
if eos_to_hc:
sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
else:
sentencelist[i] = sentencelist[i].replace(EOSX_str, '')
#someone may have had stop strings in the raw text...
sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
num_EOS += sentencelist[i].count(EOS_str)
if num_EOS > 0:
print(f"+ EOS count: {num_EOS}")
#final check for useless lines
sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
sentencelist = [item for item in sentencelist if item.strip() != ""]
if debug_slicer:
# Write the log file
Path('logs').mkdir(exist_ok=True)
sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
output_file = "logs/sentencelist.json"
with open(output_file, 'w') as f:
json.dump(sentencelist_dict, f,indent=2)
return sentencelist
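For context, here is a minimal usage sketch of the slicer above (hypothetical file path and parameter values; it assumes a model is already loaded so that `shared.tokenizer` is available):

```python
# Hypothetical usage; the dataset path and parameter values are illustrative only.
with open('training/datasets/my_raw_text.txt', 'r', encoding='utf-8') as f:
    raw_text = f.read()

blocks = precise_cut(
    raw_text,
    overlap=True,                 # also emit half-shifted overlapping blocks
    min_chars_cut=32,             # discard blocks shorter than this many characters
    eos_to_hc=True,               # turn the hard-cut placeholder into a real </s> token
    cutoff_len=256,               # token budget per block (capped at cutoff_len - 1 tokens)
    hard_cut_string='\\n\\n\\n',  # literal string in the raw text that forces a hard cut
)
print(f'{len(blocks)} training blocks produced')
```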

View file

@ -7,10 +7,12 @@ from modules import shared
from modules.chat import generate_chat_reply from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls, from modules.models_settings import get_model_metadata, update_model_parameters
update_model_parameters) from modules.text_generation import (
from modules.text_generation import (encode, generate_reply, encode,
stop_everything_event) generate_reply,
stop_everything_event
)
from modules.utils import get_available_models from modules.utils import get_available_models
@ -127,8 +129,8 @@ class Handler(BaseHTTPRequestHandler):
shared.model_name = model_name shared.model_name = model_name
unload_model() unload_model()
model_settings = get_model_settings_from_yamls(shared.model_name) model_settings = get_model_metadata(shared.model_name)
shared.settings.update(model_settings) shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
update_model_parameters(model_settings, initial=True) update_model_parameters(model_settings, initial=True)
if shared.settings['mode'] != 'instruct': if shared.settings['mode'] != 'instruct':
@ -195,7 +197,7 @@ class Handler(BaseHTTPRequestHandler):
super().end_headers() super().end_headers()
def _run_server(port: int, share: bool = False): def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1' address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
server = ThreadingHTTPServer((address, port), Handler) server = ThreadingHTTPServer((address, port), Handler)
@ -205,7 +207,7 @@ def _run_server(port: int, share: bool = False):
if share: if share:
try: try:
try_start_cloudflared(port, max_attempts=3, on_start=on_start) try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
except Exception: except Exception:
pass pass
else: else:
@ -215,5 +217,5 @@ def _run_server(port: int, share: bool = False):
server.serve_forever() server.serve_forever()
def start_server(port: int, share: bool = False): def start_server(port: int, share: bool = False, tunnel_id=str):
Thread(target=_run_server, args=[port, share], daemon=True).start() Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()

View file

@ -1,2 +1,2 @@
flask_cloudflared==0.0.12 flask_cloudflared==0.0.14
websockets==11.0.2 websockets==11.0.2

View file

@ -1,8 +1,13 @@
import time
import extensions.api.blocking_api as blocking_api import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api import extensions.api.streaming_api as streaming_api
from modules import shared from modules import shared
def setup(): def setup():
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api) blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api) if shared.args.public_api:
time.sleep(5)
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)

View file

@ -2,12 +2,15 @@ import asyncio
import json import json
from threading import Thread from threading import Thread
from websockets.server import serve from extensions.api.util import (
build_parameters,
from extensions.api.util import build_parameters, try_start_cloudflared, with_api_lock try_start_cloudflared,
with_api_lock
)
from modules import shared from modules import shared
from modules.chat import generate_chat_reply from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply from modules.text_generation import generate_reply
from websockets.server import serve
PATH = '/api/v1/stream' PATH = '/api/v1/stream'
@ -99,7 +102,7 @@ async def _run(host: str, port: int):
await asyncio.Future() # run forever await asyncio.Future() # run forever
def _run_server(port: int, share: bool = False): def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1' address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
def on_start(public_url: str): def on_start(public_url: str):
@ -108,7 +111,7 @@ def _run_server(port: int, share: bool = False):
if share: if share:
try: try:
try_start_cloudflared(port, max_attempts=3, on_start=on_start) try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
except Exception as e: except Exception as e:
print(e) print(e)
else: else:
@ -117,5 +120,5 @@ def _run_server(port: int, share: bool = False):
asyncio.run(_run(host=address, port=port)) asyncio.run(_run(host=address, port=port))
def start_server(port: int, share: bool = False): def start_server(port: int, share: bool = False, tunnel_id=str):
Thread(target=_run_server, args=[port, share], daemon=True).start() Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()

View file

@ -10,7 +10,6 @@ from modules import shared
from modules.chat import load_character_memoized from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized from modules.presets import load_preset_memoized
# We use a thread local to store the asyncio lock, so that each thread # We use a thread local to store the asyncio lock, so that each thread
# has its own lock. This isn't strictly necessary, but it makes it # has its own lock. This isn't strictly necessary, but it makes it
# such that if we can support multiple worker threads in the future, # such that if we can support multiple worker threads in the future,
@ -22,6 +21,8 @@ def build_parameters(body, chat=False):
generate_params = { generate_params = {
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))), 'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
'max_tokens_second': int(body.get('max_tokens_second', 0)),
'do_sample': bool(body.get('do_sample', True)), 'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)), 'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)), 'top_p': float(body.get('top_p', 1)),
@ -43,9 +44,12 @@ def build_parameters(body, chat=False):
'mirostat_mode': int(body.get('mirostat_mode', 0)), 'mirostat_mode': int(body.get('mirostat_mode', 0)),
'mirostat_tau': float(body.get('mirostat_tau', 5)), 'mirostat_tau': float(body.get('mirostat_tau', 5)),
'mirostat_eta': float(body.get('mirostat_eta', 0.1)), 'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
'guidance_scale': float(body.get('guidance_scale', 1)),
'negative_prompt': str(body.get('negative_prompt', '')),
'seed': int(body.get('seed', -1)), 'seed': int(body.get('seed', -1)),
'add_bos_token': bool(body.get('add_bos_token', True)), 'add_bos_token': bool(body.get('add_bos_token', True)),
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))), 'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
'custom_token_bans': str(body.get('custom_token_bans', '')),
'ban_eos_token': bool(body.get('ban_eos_token', False)), 'ban_eos_token': bool(body.get('ban_eos_token', False)),
'skip_special_tokens': bool(body.get('skip_special_tokens', True)), 'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
'custom_stopping_strings': '', # leave this blank 'custom_stopping_strings': '', # leave this blank
@ -59,34 +63,37 @@ def build_parameters(body, chat=False):
if chat: if chat:
character = body.get('character') character = body.get('character')
instruction_template = body.get('instruction_template') instruction_template = body.get('instruction_template', shared.settings['instruction_template'])
name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), shared.settings['name2'], instruct=False) if str(instruction_template) == "None":
instruction_template = "Vicuna-v1.1"
if str(character) == "None":
character = "Assistant"
name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), '', instruct=False)
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True) name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
generate_params.update({ generate_params.update({
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
'mode': str(body.get('mode', 'chat')), 'mode': str(body.get('mode', 'chat')),
'name1': name1, 'name1': str(body.get('name1', name1)),
'name2': name2, 'name2': str(body.get('name2', name2)),
'context': context, 'context': str(body.get('context', context)),
'greeting': greeting, 'greeting': str(body.get('greeting', greeting)),
'name1_instruct': name1_instruct, 'name1_instruct': str(body.get('name1_instruct', name1_instruct)),
'name2_instruct': name2_instruct, 'name2_instruct': str(body.get('name2_instruct', name2_instruct)),
'context_instruct': context_instruct, 'context_instruct': str(body.get('context_instruct', context_instruct)),
'turn_template': turn_template, 'turn_template': str(body.get('turn_template', turn_template)),
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])), 'chat-instruct_command': str(body.get('chat_instruct_command', body.get('chat-instruct_command', shared.settings['chat-instruct_command']))),
'history': body.get('history', {'internal': [], 'visible': []}) 'history': body.get('history', {'internal': [], 'visible': []})
}) })
return generate_params return generate_params
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
Thread(target=_start_cloudflared, args=[ Thread(target=_start_cloudflared, args=[
port, max_attempts, on_start], daemon=True).start() port, tunnel_id, max_attempts, on_start], daemon=True).start()
def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
try: try:
from flask_cloudflared import _run_cloudflared from flask_cloudflared import _run_cloudflared
except ImportError: except ImportError:
@ -96,6 +103,9 @@ def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Call
for _ in range(max_attempts): for _ in range(max_attempts):
try: try:
if tunnel_id is not None:
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
else:
public_url = _run_cloudflared(port, port + 1) public_url = _run_cloudflared(port, port + 1)
if on_start: if on_start:

View file

@ -1 +1 @@
elevenlabs==0.2.* elevenlabs==0.2.24

View file

@ -1,10 +1,12 @@
import html
import re import re
from pathlib import Path from pathlib import Path
import elevenlabs import elevenlabs
import gradio as gr import gradio as gr
from modules import chat, shared from modules import chat, shared, ui_chat
from modules.logging_colors import logger
from modules.utils import gradio from modules.utils import gradio
params = { params = {
@ -13,10 +15,12 @@ params = {
'selected_voice': 'None', 'selected_voice': 'None',
'autoplay': False, 'autoplay': False,
'show_text': True, 'show_text': True,
'model': 'eleven_monolingual_v1',
} }
voices = None voices = None
wav_idx = 0 wav_idx = 0
LANG_MODELS = ['eleven_monolingual_v1', 'eleven_multilingual_v1']
def update_api_key(key): def update_api_key(key):
@ -108,7 +112,7 @@ def output_modifier(string):
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx)) output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
print(f'Outputting audio to {str(output_file)}') print(f'Outputting audio to {str(output_file)}')
try: try:
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model="eleven_monolingual_v1") audio = elevenlabs.generate(text=html.unescape(string), voice=params['selected_voice'], model=params['model'])
elevenlabs.save(audio, str(output_file)) elevenlabs.save(audio, str(output_file))
autoplay = 'autoplay' if params['autoplay'] else '' autoplay = 'autoplay' if params['autoplay'] else ''
@ -132,6 +136,11 @@ def ui():
global voices global voices
if not voices: if not voices:
voices = refresh_voices() voices = refresh_voices()
selected = params['selected_voice']
if selected == 'None':
params['selected_voice'] = voices[0]
elif selected not in voices:
logger.error(f'Selected voice {selected} not available, switching to {voices[0]}')
params['selected_voice'] = voices[0] params['selected_voice'] = voices[0]
# Gradio elements # Gradio elements
@ -145,14 +154,20 @@ def ui():
refresh = gr.Button(value='Refresh') refresh = gr.Button(value='Refresh')
with gr.Row(): with gr.Row():
if params['api_key']:
api_key = gr.Textbox(value=params['api_key'], label='API Key')
update_api_key(params['api_key'])
else:
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key') api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
with gr.Row():
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
with gr.Row(): with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts') convert = gr.Button('Permanently replace audios with the message texts')
convert_cancel = gr.Button('Cancel', visible=False) convert_cancel = gr.Button('Cancel', visible=False)
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False) convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
if shared.is_chat():
# Convert history with confirmation # Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel] convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
@ -160,7 +175,7 @@ def ui():
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, gradio('history'), gradio('history')).then( remove_tts_from_history, gradio('history'), gradio('history')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, gradio('display')) chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
@ -169,12 +184,13 @@ def ui():
lambda x: params.update({"show_text": x}), show_text, None).then( lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, gradio('history'), gradio('history')).then( toggle_text_in_history, gradio('history'), gradio('history')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, gradio('display')) chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display'))
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({'activate': x}), activate, None) activate.change(lambda x: params.update({'activate': x}), activate, None)
voice.change(lambda x: params.update({'selected_voice': x}), voice, None) voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
api_key.change(update_api_key, api_key, None) api_key.change(update_api_key, api_key, None)
model.change(lambda x: params.update({'model': x}), model, None)
# connect.click(check_valid_api, [], connection_status) # connect.click(check_valid_api, [], connection_status)
refresh.click(refresh_voices_dd, [], voice) refresh.click(refresh_voices_dd, [], voice)
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend

View file

@ -0,0 +1,139 @@
"""
An example of extension. It does nothing, but you can add transformations
before the return statements to customize the webui behavior.
Starting from history_modifier and ending in output_modifier, the
functions are declared in the same order that they are called at
generation time.
"""
import gradio as gr
import torch
from transformers import LogitsProcessor
from modules import chat, shared
from modules.text_generation import (
decode,
encode,
generate_reply,
)
params = {
"display_name": "Example Extension",
"is_tab": False,
}
class MyLogits(LogitsProcessor):
"""
Manipulates the probabilities for the next token before it gets sampled.
Used in the logits_processor_modifier function below.
"""
def __init__(self):
pass
def __call__(self, input_ids, scores):
# probs = torch.softmax(scores, dim=-1, dtype=torch.float)
# probs[0] /= probs[0].sum()
# scores = torch.log(probs / (1 - probs))
return scores
def history_modifier(history):
"""
Modifies the chat history.
Only used in chat mode.
"""
return history
def state_modifier(state):
"""
Modifies the state variable, which is a dictionary containing the input
values in the UI like sliders and checkboxes.
"""
return state
def chat_input_modifier(text, visible_text, state):
"""
Modifies the user input string in chat mode (visible_text).
You can also modify the internal representation of the user
input (text) to change how it will appear in the prompt.
"""
return text, visible_text
def input_modifier(string, state, is_chat=False):
"""
In default/notebook modes, modifies the whole prompt.
In chat mode, it is the same as chat_input_modifier but only applied
to "text", here called "string", and not to "visible_text".
"""
return string
def bot_prefix_modifier(string, state):
"""
Modifies the prefix for the next bot reply in chat mode.
By default, the prefix will be something like "Bot Name:".
"""
return string
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
"""
Modifies the input ids and embeds.
Used by the multimodal extension to put image embeddings in the prompt.
Only used by loaders that use the transformers library for sampling.
"""
return prompt, input_ids, input_embeds
def logits_processor_modifier(processor_list, input_ids):
"""
Adds logits processors to the list, allowing you to access and modify
the next token probabilities.
Only used by loaders that use the transformers library for sampling.
"""
processor_list.append(MyLogits())
return processor_list
def output_modifier(string, state, is_chat=False):
"""
Modifies the LLM output before it gets presented.
In chat mode, the modified version goes into history['visible'],
and the original version goes into history['internal'].
"""
return string
def custom_generate_chat_prompt(user_input, state, **kwargs):
"""
Replaces the function that generates the prompt from the chat history.
Only used in chat mode.
"""
result = chat.generate_chat_prompt(user_input, state, **kwargs)
return result
def custom_css():
"""
Returns a CSS string that gets appended to the CSS for the webui.
"""
return ''
def custom_js():
"""
Returns a javascript string that gets appended to the javascript
for the webui.
"""
return ''
def setup():
"""
Gets executed only once, when the extension is imported.
"""
pass
def ui():
"""
Gets executed when the UI is drawn. Custom gradio elements and
their corresponding event handlers should be defined here.
To learn about gradio components, check out the docs:
https://gradio.app/docs/
"""
pass

View file

@ -0,0 +1,33 @@
let gallery_element = document.getElementById('gallery-extension');
let chat_mode_element = document.getElementById('chat-mode');
let extensions_block = document.getElementById('extensions');
let extensions_block_size = extensions_block.childNodes.length;
let gallery_only = (extensions_block_size == 5);
document.querySelector('.header_bar').addEventListener('click', function(event) {
if (event.target.tagName === 'BUTTON') {
const buttonText = event.target.textContent.trim();
let chat_visible = (buttonText == 'Chat');
let default_visible = (buttonText == 'Default');
let notebook_visible = (buttonText == 'Notebook');
let chat_mode_visible = (chat_mode_element.offsetHeight > 0 && chat_mode_element.offsetWidth > 0);
// Only show this extension in the Chat tab
if (chat_visible) {
if (chat_mode_visible) {
gallery_element.style.display = 'block';
extensions_block.style.display = '';
} else {
gallery_element.style.display = 'none';
extensions_block.style.display = 'none';
}
} else {
gallery_element.style.display = 'none';
if (gallery_only) {
extensions_block.style.display = 'none';
}
}
}
});

View file

@ -82,8 +82,13 @@ def select_character(evt: gr.SelectData):
return (evt.value[1]) return (evt.value[1])
def custom_js():
path_to_js = Path(__file__).parent.resolve() / 'script.js'
return open(path_to_js, 'r').read()
def ui(): def ui():
with gr.Accordion("Character gallery", open=False): with gr.Accordion("Character gallery", open=False, elem_id='gallery-extension'):
update = gr.Button("Refresh") update = gr.Button("Refresh")
gr.HTML(value="<style>" + generate_css() + "</style>") gr.HTML(value="<style>" + generate_css() + "</style>")
gallery = gr.Dataset(components=[gr.HTML(visible=False)], gallery = gr.Dataset(components=[gr.HTML(visible=False)],

View file

@ -1,3 +1,5 @@
import html
import gradio as gr import gradio as gr
from deep_translator import GoogleTranslator from deep_translator import GoogleTranslator
@ -27,7 +29,8 @@ def output_modifier(string):
if not params['activate']: if not params['activate']:
return string return string
return GoogleTranslator(source='en', target=params['language string']).translate(string) translated_str = GoogleTranslator(source='en', target=params['language string']).translate(html.unescape(string))
return html.escape(translated_str)
def bot_prefix_modifier(string): def bot_prefix_modifier(string):

View file

@ -1,8 +0,0 @@
import gradio as gr
from modules.logging_colors import logger
def ui():
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
logger.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")

View file

@ -0,0 +1,143 @@
import torch
from modules import chat, shared
from modules.text_generation import (
decode,
encode,
generate_reply,
)
from transformers import LogitsProcessor
import gradio as gr
params = {
"display_name": "Long replies",
"is_tab": False,
"min_length": 120,
}
initial_size = 0
class MyLogits(LogitsProcessor):
"""
Manipulates the probabilities for the next token before it gets sampled.
Used in the logits_processor_modifier function below.
"""
def __init__(self):
self.newline_id = shared.tokenizer.encode('\n')[-1]
pass
def __call__(self, input_ids, scores):
if input_ids.shape[-1] - initial_size < params["min_length"]:
scores[...,self.newline_id] = -1000
# scores[...,shared.tokenizer.eos_token_id] = -1000
# probs = torch.softmax(scores, dim=-1, dtype=torch.float)
# probs[0] /= probs[0].sum()
# scores = torch.log(probs / (1 - probs))
return scores
def history_modifier(history):
"""
Modifies the chat history.
Only used in chat mode.
"""
return history
def state_modifier(state):
"""
Modifies the state variable, which is a dictionary containing the input
values in the UI like sliders and checkboxes.
"""
return state
def chat_input_modifier(text, visible_text, state):
"""
Modifies the user input string in chat mode (visible_text).
You can also modify the internal representation of the user
input (text) to change how it will appear in the prompt.
"""
return text, visible_text
def input_modifier(string, state):
"""
In default/notebook modes, modifies the whole prompt.
In chat mode, it is the same as chat_input_modifier but only applied
to "text", here called "string", and not to "visible_text".
"""
return string
def bot_prefix_modifier(string, state):
"""
Modifies the prefix for the next bot reply in chat mode.
By default, the prefix will be something like "Bot Name:".
"""
return string
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
"""
Modifies the input ids and embeds.
Used by the multimodal extension to put image embeddings in the prompt.
Only used by loaders that use the transformers library for sampling.
"""
global initial_size
initial_size = input_ids.shape[-1]
return prompt, input_ids, input_embeds
def logits_processor_modifier(processor_list, input_ids):
"""
Adds logits processors to the list, allowing you to access and modify
the next token probabilities.
Only used by loaders that use the transformers library for sampling.
"""
processor_list.append(MyLogits())
return processor_list
def output_modifier(string, state):
"""
Modifies the LLM output before it gets presented.
In chat mode, the modified version goes into history['visible'],
and the original version goes into history['internal'].
"""
return string
def custom_generate_chat_prompt(user_input, state, **kwargs):
"""
Replaces the function that generates the prompt from the chat history.
Only used in chat mode.
"""
result = chat.generate_chat_prompt(user_input, state, **kwargs)
return result
def custom_css():
"""
Returns a CSS string that gets appended to the CSS for the webui.
"""
return ''
def custom_js():
"""
Returns a javascript string that gets appended to the javascript
for the webui.
"""
return ''
def setup():
"""
Gets executed only once, when the extension is imported.
"""
pass
def ui():
"""
Gets executed when the UI is drawn. Custom gradio elements and
their corresponding event handlers should be defined here.
To learn about gradio components, check out the docs:
https://gradio.app/docs/
"""
min_length = gr.Slider(0, 800, step=10, value=params['min_length'], label='Minimum reply length')
min_length.change(lambda x: params.update({'min_length': x}), min_length, None)

View file

@ -11,10 +11,10 @@ https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b
To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples: To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples:
``` ```
python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b
python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b
python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b
python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b
``` ```
There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`: There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`:
@ -38,6 +38,8 @@ As of now, the following multimodal pipelines are supported:
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in| |[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)| |[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)| |[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|[InstructBLIP 7B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-7b`|[Vicuna v1.1 7B](https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)|
|[InstructBLIP 13B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-13b`|[Vicuna v1.1 13B](https://huggingface.co/TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)|
Some pipelines could support different LLMs but do note that while it might work, it isn't a supported configuration. Some pipelines could support different LLMs but do note that while it might work, it isn't a supported configuration.

View file

@ -56,10 +56,13 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
@staticmethod @staticmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor: def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
if hasattr(shared.model.model, 'embed_tokens'): for attr in ['', 'model', 'model.model', 'model.model.model']:
func = shared.model.model.embed_tokens tmp = getattr(shared.model, attr, None) if attr != '' else shared.model
if tmp is not None and hasattr(tmp, 'embed_tokens'):
func = tmp.embed_tokens
break
else: else:
func = shared.model.model.model.embed_tokens # AutoGPTQ case raise ValueError('The embed_tokens method has not been found for this loader.')
return func(input_ids).to(shared.model.device, dtype=shared.model.dtype) return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)

View file

@ -35,6 +35,15 @@ input_hijack = {
multimodal_embedder: MultimodalEmbedder = None multimodal_embedder: MultimodalEmbedder = None
def chat_input_modifier(text, visible_text, state):
global input_hijack
if input_hijack['state']:
input_hijack['state'] = False
return input_hijack['value'](text, visible_text)
else:
return text, visible_text
def add_chat_picture(picture, text, visible_text): def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable) # resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size) max_hw, min_hw = max(picture.size), min(picture.size)

View file

@ -22,6 +22,7 @@ options = {
'session_metadata': 'text-generation-webui', 'session_metadata': 'text-generation-webui',
} }
def ui(): def ui():
settings = shared.settings.get("ngrok") settings = shared.settings.get("ngrok")
if settings: if settings:
@ -33,4 +34,3 @@ def ui():
logging.info(f"Ingress established at: {tunnel.url()}") logging.info(f"Ingress established at: {tunnel.url()}")
except ModuleNotFoundError: except ModuleNotFoundError:
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`") logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")

View file

@ -1,17 +1,16 @@
# An OpenedAI API (openai like) # An OpenedAI API (openai like)
This extension creates an API that works kind of like openai (ie. api.openai.com). This extension creates an API that works kind of like openai (ie. api.openai.com).
It's incomplete so far but perhaps is functional enough for you.
## Setup & installation ## Setup & installation
Optional (for flask_cloudflared, embeddings): Install the requirements:
``` ```
pip3 install -r requirements.txt pip3 install -r requirements.txt
``` ```
It listens on tcp port 5001 by default. You can use the OPENEDAI_PORT environment variable to change this. It listens on `tcp port 5001` by default. You can use the `OPENEDAI_PORT` environment variable to change this.
Make sure you enable it in server launch parameters, it should include: Make sure you enable it in server launch parameters, it should include:
@ -19,15 +18,44 @@ Make sure you enable it in server launch parameters, it should include:
--extensions openai --extensions openai
``` ```
You can also use the ``--listen`` argument to make the server available on the network, and/or the ```--share``` argument to enable a public Cloudflare endpoint. You can also use the `--listen` argument to make the server available on the network, and/or the `--share` argument to enable a public Cloudflare endpoint.
To enable the basic image generation support (txt2img) set the environment variable SD_WEBUI_URL to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)). To enable the basic image generation support (txt2img) set the environment variable `SD_WEBUI_URL` to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
For example: For example:
``` ```
SD_WEBUI_URL=http://127.0.0.1:7861 SD_WEBUI_URL=http://127.0.0.1:7861
``` ```
## Quick start
1. Install the requirements.txt (pip)
2. Enable the `openai` extension (`--extensions openai`), then restart the server.
3. Configure the openai client
Most openai applications can be configured to connect to the API if you set the following environment variables:
```shell
# Sample .env file:
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://0.0.0.0:5001/v1
```
If needed, replace 0.0.0.0 with the IP/port of your server.
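As a quick sanity check that a client can reach the extension (a sketch assuming the defaults above and the legacy `openai` Python client):

```python
import os
os.environ['OPENAI_API_KEY'] = 'sk-111111111111111111111111111111111111111111111111'
os.environ['OPENAI_API_BASE'] = 'http://127.0.0.1:5001/v1'

import openai

models = openai.Model.list()
# The currently loaded model should appear first in the list.
print([m['id'] for m in models['data']])
```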
## Settings
To adjust your default settings, you can add the following to your `settings.yaml` file.
```
openai-port: 5002
openai-embedding_device: cuda
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```
### Models ### Models
This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat. This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat.
@ -36,9 +64,9 @@ For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](h
For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama. For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama.
For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following - within the limits of the model. Be sure that the proper instruction template is detected and loaded or the results will not be good. For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following. Be sure that the proper instruction template is detected and loaded or the results will not be good.
For the proper instruction format to be detected you need to have a matching model entry in your ```models/config.yaml``` file. Be sure to keep this file up to date. For the proper instruction format to be detected you need to have a matching model entry in your `models/config.yaml` file. Be sure to keep this file up to date.
A matching instruction template file in the characters/instruction-following/ folder will be loaded and applied to format messages correctly for the model - this is critical for good results.
For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry: For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry:
@ -49,7 +77,7 @@ For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1
instruction_template: 'Vicuna-v1.1' instruction_template: 'Vicuna-v1.1'
``` ```
This refers to ```characters/instruction-following/Vicuna-v1.1.yaml```, which looks like this: This refers to `characters/instruction-following/Vicuna-v1.1.yaml`, which looks like this:
``` ```
user: "USER:" user: "USER:"
@ -61,63 +89,66 @@ context: "A chat between a curious user and an artificial intelligence assistant
For most common models this is already set up, but if you are using a new or uncommon model you may need to add a matching entry to the models/config.yaml file and possibly create your own instruction-following template for best results.
If you see this in your logs, it probably means that the correct format could not be loaded: If you see this in your logs, it probably means that the correct format could not be loaded:
``` ```
Warning: Loaded default instruction-following template for model. Warning: Loaded default instruction-following template for model.
``` ```
### Embeddings (alpha) ### Embeddings (alpha)
Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future. Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
| model name | dimensions | input max tokens | speed | size | Avg. performance | | model name | dimensions | input max tokens | speed | size | Avg. performance |
| --- | --- | --- | --- | --- | --- | | ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
| text-embedding-ada-002 | 1536 | 8192 | - | - | - | | text-embedding-ada-002 | 1536 | 8192 | - | - | - |
| text-davinci-002 | 768 | 2046 | - | - | - | | text-davinci-002 | 768 | 2046 | - | - | - |
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 | | all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 | | all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable OPENEDAI_EMBEDDING_MODEL, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2". In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable. Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
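For reference, a minimal sketch of calling the embeddings endpoint with the legacy `openai` Python client (the `model` argument is sent for API compatibility; the embedding model actually used is the one configured server-side):

```python
import os
os.environ['OPENAI_API_KEY'] = 'sk-111111111111111111111111111111111111111111111111'
os.environ['OPENAI_API_BASE'] = 'http://127.0.0.1:5001/v1'

import openai

response = openai.Embedding.create(
    model='text-embedding-ada-002',  # accepted for compatibility; not used to switch models
    input=['The sliced bread was a great invention.'],
)
vector = response['data'][0]['embedding']
print(len(vector))  # 768 with the default all-mpnet-base-v2 model
```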
### Client Application Setup ### Client Application Setup
Almost everything you use it with will require you to set a dummy OpenAI API key environment variable. Almost everything you use it with will require you to set a dummy OpenAI API key environment variable.
With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so: With the [official python openai client](https://github.com/openai/openai-python), set the `OPENAI_API_BASE` environment variables:
``` ```shell
# Sample .env file:
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111 OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://127.0.0.1:5001/v1 OPENAI_API_BASE=http://0.0.0.0:5001/v1
``` ```
If needed, replace 127.0.0.1 with the IP/port of your server. If needed, replace 0.0.0.0 with the IP/port of your server.
If using .env files to save the OPENAI_API_BASE and OPENAI_API_KEY variables, you can ensure compatibility by loading the .env file before loading the openai module, like so in python: If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported:
``` ```python
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv() # make sure the environment variables are set before import
import openai import openai
``` ```
With the [official Node.js openai client](https://github.com/openai/openai-node) it is slightly more complex because the environment variables are not used by default, so small source code changes may be required to use the environment variables, like so: With the [official Node.js openai client](https://github.com/openai/openai-node) it is slightly more complex because the environment variables are not used by default, so small source code changes may be required to use the environment variables, like so:
``` ```js
const openai = OpenAI(Configuration({ const openai = OpenAI(
Configuration({
apiKey: process.env.OPENAI_API_KEY, apiKey: process.env.OPENAI_API_KEY,
basePath: process.env.OPENAI_API_BASE, basePath: process.env.OPENAI_API_BASE
})); })
);
``` ```
For apps made with the [chatgpt-api Node.js client library](https://github.com/transitive-bullshit/chatgpt-api): For apps made with the [chatgpt-api Node.js client library](https://github.com/transitive-bullshit/chatgpt-api):
``` ```js
const api = new ChatGPTAPI({ const api = new ChatGPTAPI({
apiKey: process.env.OPENAI_API_KEY, apiKey: process.env.OPENAI_API_KEY,
apiBaseUrl: process.env.OPENAI_API_BASE, apiBaseUrl: process.env.OPENAI_API_BASE
}) });
``` ```
## API Documentation & Examples ## API Documentation & Examples
@ -127,89 +158,81 @@ The OpenAI API is well documented, you can view the documentation here: https://
Examples of how to use the Completions API in Python can be found here: https://platform.openai.com/examples Examples of how to use the Completions API in Python can be found here: https://platform.openai.com/examples
Not all of them will work with all models, unfortunately; see the notes on Models for how to get the best results. Not all of them will work with all models, unfortunately; see the notes on Models for how to get the best results.
Here is a simple python example of how you can use the Edit endpoint as a translator. Here is a simple python example.
```python ```python
import os
os.environ['OPENAI_API_KEY']="sk-111111111111111111111111111111111111111111111111"
os.environ['OPENAI_API_BASE']="http://0.0.0.0:5001/v1"
import openai import openai
response = openai.Edit.create(
response = openai.ChatCompletion.create(
model="x", model="x",
instruction="Translate this into French", messages = [{ 'role': 'system', 'content': "Answer in a consistent style." },
input="Our mission is to ensure that artificial general intelligence benefits all of humanity.", {'role': 'user', 'content': "Teach me about patience."},
{'role': 'assistant', 'content': "The river that carves the deepest valley flows from a modest spring; the grandest symphony originates from a single note; the most intricate tapestry begins with a solitary thread."},
{'role': 'user', 'content': "Teach me about the ocean."},
]
) )
print(response['choices'][0]['text']) text = response['choices'][0]['message']['content']
# Sample Output: print(text)
# Notre mission est de garantir que l'intelligence artificielle généralisée profite à tous les membres de l'humanité.
``` ```
## Compatibility & not so compatibility ## Compatibility & not so compatibility
| API endpoint | tested with | notes | | API endpoint | tested with | notes |
| --- | --- | --- | | ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- |
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options | | /v1/chat/completions | openai.ChatCompletion.create() | Use it with instruction following models |
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for, model does nothing yet anyways | | /v1/embeddings | openai.Embedding.create() | Using SentenceTransformer embeddings |
| /v1/text_completion | openai.Completion.create() | the most tested, only supports single string input so far, variable quality based on the model |
| /v1/chat/completions | openai.ChatCompletion.create() | Quality depends a lot on the model |
| /v1/edits | openai.Edit.create() | Works the best of all, perfect for instruction following models |
| /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. | | /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. |
| /v1/embeddings | openai.Embedding.create() | Using Sentence Transformer, dimensions are different and may never be directly comparable to openai embeddings. | | /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
| /v1/moderations | openai.Moderation.create() | does nothing. successfully. | | /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) | | /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
| /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint | | /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint |
| /v1/engines/*/generate | openai engines.generate | Legacy endpoint | | /v1/engines/\*/generate | openai engines.generate | Legacy endpoint |
| /v1/engines | openai engines.list | Legacy Lists models | | /v1/engines | openai engines.list | Legacy Lists models |
| /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api | | /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api or command line |
| /v1/images/edits | openai.Image.create_edit() | not yet supported | | /v1/images/edits | openai.Image.create_edit() | not yet supported |
| /v1/images/variations | openai.Image.create_variation() | not yet supported | | /v1/images/variations | openai.Image.create_variation() | not yet supported |
| /v1/audio/\* | openai.Audio.\* | not yet supported | | /v1/audio/\* | openai.Audio.\* | supported |
| /v1/files\* | openai.Files.\* | not yet supported | | /v1/files\* | openai.Files.\* | not yet supported |
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported | | /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
| /v1/search | openai.search, engines.search | not yet supported | | /v1/search | openai.search, engines.search | not yet supported |
The model name setting is ignored in completions, but you may need to adjust the maximum token length to fit the model (ie. set to <2048 tokens instead of 4096, 8k, etc). To mitigate some of this, the max_tokens value is halved until it is less than truncation_length for the model (typically 2k). Because of the differences in OpenAI model context sizes (2k, 4k, 8k, 16k, etc.) you may need to adjust the max_tokens to fit into the context of the model you choose.
Streaming, temperature, top_p, max_tokens, stop, should all work as expected, but not all parameters are mapped correctly. Streaming, temperature, top_p, max_tokens, stop, should all work as expected, but not all parameters are mapped correctly.
Some hacky mappings: Some hacky mappings:
| OpenAI | text-generation-webui | note |
| --- | --- | --- |
| model | - | Ignored, the model is not changed |
| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
| best_of | top_k | default is 1 (top_k is 20 for chat, which doesn't support best_of) |
| stop | custom_stopping_strings | this is also stuffed with ['\n###', "\n{user prompt}", "{user prompt}" ] for good measure. |
| n | 1 | variations are not supported yet. |
| 1 | num_beams | hardcoded to 1 |
| 1.0 | typical_p | hardcoded to 1.0 |
| max_tokens | max_new_tokens | For Text Completions, max_tokens is set smaller than the truncation_length minus the prompt length. This can cause no output to be generated if the prompt is too large. For ChatCompletions, older chat messages may be dropped to fit the max_new_tokens requested |
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
| messages.name | - | not supported yet |
| suffix | - | not supported yet |
| user | - | not supported yet |
| functions/function_call | - | function calls are not supported yet |
Defaults are mostly taken from OpenAI, so they differ from the web UI defaults. I use the OpenAI defaults where I can and try to scale them to the webui defaults with the same intent.
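For example, a minimal sketch of calling the API through openai-python (v0.25+, pre-1.0 API); it assumes the extension is listening on the default port 5001, as in the `OPENAI_API_BASE` values further below, and that a ~2k-context model is loaded, so max_tokens is kept well under the truncation length:

```python
import openai

openai.api_key = "sk-dummy"                    # any non-empty key works
openai.api_base = "http://127.0.0.1:5001/v1"   # point the client at the local extension

response = openai.Completion.create(
    model="text-davinci-003",  # ignored by the backend; the currently loaded model is used
    prompt="Write a haiku about local inference.",
    max_tokens=200,            # keep prompt tokens + max_tokens below truncation_length
    temperature=0.7,
)
print(response["choices"][0]["text"])
```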
### Applications
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variables set, for example:

```
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://127.0.0.1:5001/v1
```

Some apps are picky about the key format, but 'dummy' or 'sk-dummy' also works in most cases. There are some exceptions, noted in the compatibility table below.
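As a quick sanity check before wiring up a third-party app, here is a minimal sketch (assuming openai-python v0.25+ with the pre-1.0 API, which reads both variables from the environment, and that the chat completions endpoint is enabled):

```python
# With OPENAI_API_KEY and OPENAI_API_BASE exported as above,
# openai-python needs no further configuration.
import openai

reply = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # ignored; the currently loaded model answers
    messages=[{"role": "user", "content": "Say hello in five words."}],
    max_tokens=32,
)
print(reply["choices"][0]["message"]["content"])
```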
| Compatibility | Application/Library | Website | Notes |
| --- | --- | --- | --- |
| ✅❌ | openai-python (v0.25+) | https://github.com/openai/openai-python | only the endpoints from above are working. OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ✅❌ | openai-node | https://github.com/openai/openai-node | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
| ✅❌ | chatgpt-api | https://github.com/transitive-bullshit/chatgpt-api | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
| ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI, Images also work |
| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 |
| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
@@ -218,15 +241,16 @@
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
## Future plans

- better error handling
- model changing, esp. something for swapping loras or embedding models
- consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
- do something about rate limiting or locking requests for completions; most systems will only be able to handle a single request at a time before OOM
## Bugs? Feedback? Comments? Pull requests?

To enable debugging and get copious output you can set the `OPENEDAI_DEBUG=1` environment variable.

All are appreciated; please tag @matatonic and I'll try to get back to you as soon as possible.


@@ -3,6 +3,10 @@
# Dockerfile:
# ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
# RUN python3 cache_embedded_model.py
import os

import sentence_transformers

from extensions.openai.script import params

st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
model = sentence_transformers.SentenceTransformer(st_model)
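For reference, a quick sketch of what the cached model is used for; the `encode` call below is the standard sentence-transformers API, while wrapping the vectors into an OpenAI-style embeddings response happens elsewhere in the extension:

```python
import sentence_transformers

# all-mpnet-base-v2 is the default model above; it produces 768-dimensional vectors.
model = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
vectors = model.encode(["first input", "second input"])
print(vectors.shape)  # (2, 768)
```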


@@ -0,0 +1,637 @@
import time
import tiktoken
import torch
import torch.nn.functional as F
import yaml
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg, end_line
from modules import shared
from modules.text_generation import decode, encode, generate_reply
from transformers import LogitsProcessor, LogitsProcessorList
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
class LogitsBiasProcessor(LogitsProcessor):
def __init__(self, logit_bias={}):
self.logit_bias = logit_bias
if self.logit_bias:
self.keys = list([int(key) for key in self.logit_bias.keys()])
values = [self.logit_bias[str(key)] for key in self.keys]
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
debug_msg(f"{self}")
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logit_bias:
debug_msg(logits[0, self.keys], " + ", self.values)
logits[0, self.keys] += self.values
debug_msg(" --> ", logits[0, self.keys])
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
return logits
def __repr__(self):
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
class LogprobProcessor(LogitsProcessor):
def __init__(self, logprobs=None):
self.logprobs = logprobs
self.token_alternatives = {}
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logprobs is not None: # 0-5
log_e_probabilities = F.log_softmax(logits, dim=1)
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_tokens = [decode(tok) for tok in top_indices[0]]
top_probs = [float(x) for x in top_values[0]]
self.token_alternatives = dict(zip(top_tokens, top_probs))
debug_msg(repr(self))
return logits
def __repr__(self):
return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
def convert_logprobs_to_tiktoken(model, logprobs):
# more problems than it's worth.
# try:
# encoder = tiktoken.encoding_for_model(model)
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
# except KeyError:
# # assume native tokens if we can't find the tokenizer
# return logprobs
return logprobs
def marshal_common_params(body):
# Request Parameters
# Try to use openai defaults or map them to something with the same intent
req_params = get_default_req_params()
# Common request parameters
req_params['truncation_length'] = shared.settings['truncation_length']
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
# OpenAI API Parameters
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
req_params['requested_model'] = body.get('model', shared.model_name)
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)
n = default(body, 'n', 1)
if n != 1:
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
if 'stop' in body: # str or array, max len 4 (ignored)
if isinstance(body['stop'], str):
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
elif isinstance(body['stop'], list):
req_params['stopping_strings'] = body['stop']
# presence_penalty - ignored
# frequency_penalty - ignored
# pass through unofficial params
req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])
# user - ignored
logits_processor = []
logit_bias = body.get('logit_bias', None)
if logit_bias: # {str: float, ...}
# XXX convert tokens from tiktoken based on requested model
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
try:
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
new_logit_bias = {}
for logit, bias in logit_bias.items():
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
continue
new_logit_bias[str(int(x))] = bias
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
logit_bias = new_logit_bias
except KeyError:
pass # assume native tokens if we can't find the tokenizer
logits_processor = [LogitsBiasProcessor(logit_bias)]
logprobs = None # coming to chat eventually
if 'logprobs' in body:
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
req_params['logprob_proc'] = LogprobProcessor(logprobs)
logits_processor.extend([req_params['logprob_proc']])
else:
logprobs = None
if logits_processor: # requires logits_processor support
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
return req_params
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
# functions
if body.get('functions', []): # chat only
raise InvalidRequestError(message="functions is not supported.", param='functions')
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
if 'messages' not in body:
raise InvalidRequestError(message="messages is required", param='messages')
messages = body['messages']
role_formats = {
'user': 'User: {message}\n',
'assistant': 'Assistant: {message}\n',
'system': '{message}',
'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
'prompt': 'Assistant:',
}
if 'stopping_strings' not in req_params:
req_params['stopping_strings'] = []
# Instruct models can be much better
if shared.settings['instruction_template']:
try:
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
system_message_template = "{message}"
system_message_default = instruct.get('context', '') # can be missing
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
role_formats = {
'user': user_message_template,
'assistant': bot_message_template,
'system': system_message_template,
'context': system_message_default,
'prompt': bot_prompt,
}
if 'Alpaca' in shared.settings['instruction_template']:
req_params['stopping_strings'].extend(['\n###'])
elif instruct['user']: # WizardLM and some others have no user prompt.
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
except Exception as e:
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template for model.")
else:
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
print("Warning: Loaded default instruction-following template for model.")
system_msgs = []
chat_msgs = []
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
context_msg = end_line(context_msg)
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
if 'prompt' in body:
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
for m in messages:
if 'role' not in m:
raise InvalidRequestError(message="messages: missing role", param='messages')
if 'content' not in m:
raise InvalidRequestError(message="messages: missing content", param='messages')
role = m['role']
content = m['content']
# name = m.get('name', None)
# function_call = m.get('function_call', None) # user name or function name with output in content
msg = role_formats[role].format(message=content)
if role == 'system':
system_msgs.extend([msg])
elif role == 'function':
raise InvalidRequestError(message="role: function is not supported.", param='messages')
else:
chat_msgs.extend([msg])
system_msg = '\n'.join(system_msgs)
system_msg = end_line(system_msg)
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
token_count = len(encode(prompt)[0])
if token_count >= req_params['truncation_length']:
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
raise InvalidRequestError(message=err_msg, param='messages')
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
print(f"Warning: {err_msg}")
# raise InvalidRequestError(message=err_msg, params='max_tokens')
return prompt, token_count
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
# Chat Completions
object_type = 'chat.completions'
created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# common params
req_params = marshal_common_params(body)
req_params['stream'] = False
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
# chat default max_tokens is 'inf', but also flexible
max_tokens = 0
max_tokens_str = 'length' if is_legacy else 'max_tokens'
if max_tokens_str in body:
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
req_params['max_new_tokens'] = max_tokens
else:
req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
stopping_strings = req_params.pop('stopping_strings', [])
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
answer = a
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
stop_reason = "length"
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name, # TODO: add Lora info?
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"message": {"role": "assistant", "content": answer}
}],
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
}
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# resp[resp_list][0]["logprobs"] = None
return resp
# generator
def stream_chat_completions(body: dict, is_legacy: bool = False):
# Chat Completions
stream_object_type = 'chat.completions.chunk'
created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# common params
req_params = marshal_common_params(body)
req_params['stream'] = True
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
# chat default max_tokens is 'inf', but also flexible
max_tokens = 0
max_tokens_str = 'length' if is_legacy else 'max_tokens'
if max_tokens_str in body:
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
req_params['max_new_tokens'] = max_tokens
else:
req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
def chat_streaming_chunk(content):
# begin streaming
chunk = {
"id": cmpl_id,
"object": stream_object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
"index": 0,
"finish_reason": None,
# So yeah... do both methods? delta and messages.
"message": {'role': 'assistant', 'content': content},
"delta": {'role': 'assistant', 'content': content},
}],
}
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# chunk[resp_list][0]["logprobs"] = None
return chunk
yield chat_streaming_chunk('')
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
seen_content = ''
completion_token_count = 0
for a in generator:
answer = a
len_seen = len(seen_content)
new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
chunk = chat_streaming_chunk(new_content)
yield chunk
# to get the correct token_count, strip leading space if present
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
stop_reason = "length"
chunk = chat_streaming_chunk('')
chunk[resp_list][0]['finish_reason'] = stop_reason
chunk['usage'] = {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
yield chunk
def completions(body: dict, is_legacy: bool = False):
# Legacy
# Text Completions
object_type = 'text_completion'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt_arg = body[prompt_str]
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
prompt_arg = [prompt_arg]
# common params
req_params = marshal_common_params(body)
req_params['stream'] = False
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
resp_list_data = []
total_completion_token_count = 0
total_prompt_token_count = 0
for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt[0], int):
# token lists
if requested_model == shared.model_name:
prompt = decode(prompt)[0]
else:
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
answer = a
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": answer,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
resp_list_data.extend([respi])
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name, # TODO: add Lora info?
resp_list: resp_list_data,
"usage": {
"prompt_tokens": total_prompt_token_count,
"completion_tokens": total_completion_token_count,
"total_tokens": total_prompt_token_count + total_completion_token_count
}
}
return resp
# generator
def stream_completions(body: dict, is_legacy: bool = False):
# Legacy
# Text Completions
# object_type = 'text_completion'
stream_object_type = 'text_completion.chunk'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt = body[prompt_str]
req_params = marshal_common_params(body)
requested_model = req_params.pop('requested_model')
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
# common params
req_params['stream'] = True
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
req_params['max_new_tokens'] = max_tokens
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
token_count = len(encode(prompt)[0])
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
def text_streaming_chunk(content):
# begin streaming
chunk = {
"id": cmpl_id,
"object": stream_object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}],
}
return chunk
yield text_streaming_chunk('')
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
seen_content = ''
completion_token_count = 0
for a in generator:
answer = a
len_seen = len(seen_content)
new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
chunk = text_streaming_chunk(new_content)
yield chunk
# to get the correct count, we strip the leading space if present
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
chunk = text_streaming_chunk('')
chunk[resp_list][0]["finish_reason"] = stop_reason
chunk["usage"] = {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
yield chunk


@@ -0,0 +1,73 @@
import copy
# Slightly different defaults for OpenAI's API
# Data type is important, Ex. use 0.0 for a float 0
default_req_params = {
'max_new_tokens': 16, # 'Inf' for chat
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'temperature': 1.0,
'top_p': 1.0,
'top_k': 1, # choose 20 for chat in absence of another default
'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1.0,
'suffix': None,
'stream': False,
'echo': False,
'seed': -1,
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
'truncation_length': 2048, # first use shared.settings value
'add_bos_token': True,
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0.0, # In units of 1e-4
'eta_cutoff': 0.0, # In units of 1e-4
'tfs': 1.0,
'top_a': 0.0,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0.0,
'length_penalty': 1.0,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'custom_stopping_strings': '',
# 'logits_processor' - conditionally passed
# 'stopping_strings' - temporarily used
# 'logprobs' - temporarily used
# 'requested_model' - temporarily used
}
def get_default_req_params():
return copy.deepcopy(default_req_params)
def default(dic, key, default):
'''
Little helper to look up `key` in `dic`: fall back to `default` when the key is missing, and coerce the value to the type of `default` when that can be done without losing information, otherwise return `default`.
'''
val = dic.get(key, default)
if not isinstance(val, type(default)):
# maybe it's just something like 1 instead of 1.0
try:
v = type(default)(val)
if type(val)(v) == val: # if it's the same value passed in, it's ok.
return v
except:
pass
val = default
return val
def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue))
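For illustration, a short sketch of how these helpers behave (importable the same way completions.py imports them; the values are chosen arbitrarily):

```python
from extensions.openai.defaults import clamp, default

# default() falls back when the key is missing and coerces compatible types:
print(default({'temperature': 1}, 'temperature', 1.0))  # -> 1.0 (int losslessly coerced to float)
print(default({'top_p': 'high'}, 'top_p', 1.0))          # -> 1.0 (value can't be coerced, fall back)
print(default({}, 'top_k', 1))                           # -> 1   (missing key)

# clamp() keeps OpenAI-range values inside what the webui samplers accept:
print(clamp(2.0, 0.01, 1.99))                            # -> 1.99
```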
