diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..51e26b1 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,3 @@ +## Checklist: + +- [ ] I have read the [Contributing guidelines](https://github.com/oobabooga/text-generation-webui/wiki/Contributing-guidelines). diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index ce603a4..2de6d95 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,8 +13,8 @@ jobs: - uses: actions/stale@v5 with: stale-issue-message: "" - close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, please leave a comment below." - days-before-issue-stale: 30 + close-issue-message: "This issue has been closed due to inactivity for 6 weeks. If you believe it is still relevant, please leave a comment below. You can tag a developer in your comment." + days-before-issue-stale: 42 days-before-issue-close: 0 stale-issue-label: "stale" days-before-pr-stale: -1 diff --git a/README.md b/README.md index 3a76a0c..11bd0e4 100644 --- a/README.md +++ b/README.md @@ -4,30 +4,28 @@ # Text generation web UI -A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA. +A Gradio web UI for Large Language Models. Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation. -|![Image1](https://github.com/oobabooga/screenshots/raw/main/qa.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/cai3.png) | +|![Image1](https://github.com/oobabooga/screenshots/raw/main/print_instruct.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/print_chat.png) | |:---:|:---:| -|![Image3](https://github.com/oobabooga/screenshots/raw/main/gpt4chan.png) | ![Image4](https://github.com/oobabooga/screenshots/raw/main/galactica.png) | +|![Image1](https://github.com/oobabooga/screenshots/raw/main/print_default.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/print_parameters.png) | ## Features -* 3 interface modes: default, notebook, and chat -* Multiple model backends: tranformers, llama.cpp, AutoGPTQ, GPTQ-for-LLaMa, ExLlama, RWKV, FlexGen +* 3 interface modes: default (two columns), notebook, and chat +* Multiple model backends: [transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [ExLlama](https://github.com/turboderp/exllama), [ExLlamaV2](https://github.com/turboderp/exllamav2), [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa), [CTransformers](https://github.com/marella/ctransformers) * Dropdown menu for quickly switching between different models -* LoRA: load and unload LoRAs on the fly, load multiple LoRAs at the same time, train a new LoRA -* Precise instruction templates for chat mode, including Alpaca, Vicuna, Open Assistant, Dolly, Koala, ChatGLM, MOSS, RWKV-Raven, Galactica, StableLM, WizardLM, Baize, Ziya, Chinese-Vicuna, MPT, INCITE, Wizard Mega, KoAlpaca, Vigogne, Bactrian, h2o, and OpenBuddy +* LoRA: load and unload LoRAs on the fly, train a new LoRA using QLoRA +* Precise instruction templates for chat mode, including Llama-2-chat, Alpaca, Vicuna, WizardLM, StableLM, and many others +* 4-bit, 8-bit, and CPU inference through the transformers library +* Use llama.cpp models with transformers samplers (`llamacpp_HF` loader) * [Multimodal pipelines, including LLaVA and MiniGPT-4](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) -* 8-bit and 4-bit inference through bitsandbytes -* CPU mode for transformers models -* [DeepSpeed ZeRO-3 inference](docs/DeepSpeed.md) -* [Extensions](docs/Extensions.md) +* [Extensions framework](docs/Extensions.md) * [Custom chat characters](docs/Chat-mode.md) * Very efficient text streaming * Markdown output with LaTeX rendering, to use for instance with [GALACTICA](https://github.com/paperswithcode/galai) -* Nice HTML output for GPT-4chan * API, including endpoints for websocket streaming ([see the examples](https://github.com/oobabooga/text-generation-webui/blob/main/api-examples)) To learn how to use the various features, check out the Documentation: https://github.com/oobabooga/text-generation-webui/tree/main/docs @@ -42,26 +40,24 @@ To learn how to use the various features, check out the Documentation: https://g Just download the zip above, extract it, and double-click on "start". The web UI and all its dependencies will be installed in the same folder. -* The source codes are here: https://github.com/oobabooga/one-click-installers +* The source codes and more information can be found here: https://github.com/oobabooga/one-click-installers * There is no need to run the installers as admin. -* AMD doesn't work on Windows. * Huge thanks to [@jllllll](https://github.com/jllllll), [@ClayShoaf](https://github.com/ClayShoaf), and [@xNul](https://github.com/xNul) for their contributions to these installers. ### Manual installation using Conda -Recommended if you have some experience with the command line. +Recommended if you have some experience with the command-line. #### 0. Install Conda https://docs.conda.io/en/latest/miniconda.html -On Linux or WSL, it can be automatically installed with these two commands: +On Linux or WSL, it can be automatically installed with these two commands ([source](https://educe-ubc.github.io/conda.html)): ``` curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh" bash Miniconda3.sh ``` -Source: https://educe-ubc.github.io/conda.html #### 1. Create a new conda environment @@ -75,17 +71,14 @@ conda activate textgen | System | GPU | Command | |--------|---------|---------| | Linux/WSL | NVIDIA | `pip3 install torch torchvision torchaudio` | +| Linux/WSL | CPU only | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu` | | Linux | AMD | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2` | -| MacOS + MPS (untested) | Any | `pip3 install torch torchvision torchaudio` | +| MacOS + MPS | Any | `pip3 install torch torchvision torchaudio` | | Windows | NVIDIA | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117` | +| Windows | CPU only | `pip3 install torch torchvision torchaudio` | The up-to-date commands can be found here: https://pytorch.org/get-started/locally/. -#### 2.1 Special instructions - -* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393 -* AMD users: https://rentry.org/eq3hg - #### 3. Install the web UI ``` @@ -94,13 +87,30 @@ cd text-generation-webui pip install -r requirements.txt ``` -#### llama.cpp with GPU acceleration +#### AMD, Metal, Intel Arc, and CPUs without AVX2 -Requires the additional compilation step described here: [GPU acceleration](https://github.com/oobabooga/text-generation-webui/blob/main/docs/llama.cpp-models.md#gpu-acceleration). +1) Replace the last command above with -#### bitsandbytes +``` +pip install -r requirements_nocuda.txt +``` -bitsandbytes >= 0.39 may not work on older NVIDIA GPUs. In that case, to use `--load-in-8bit`, you may have to downgrade like this: +2) Manually install llama-cpp-python using the appropriate command for your hardware: [Installation from PyPI](https://github.com/abetlen/llama-cpp-python#installation-from-pypi). + +3) Do the same for CTransformers: [Installation](https://github.com/marella/ctransformers#installation). + +4) AMD: Manually install AutoGPTQ: [Installation](https://github.com/PanQiWei/AutoGPTQ#installation). + +5) AMD: Manually install [ExLlama](https://github.com/turboderp/exllama) by simply cloning it into the `repositories` folder (it will be automatically compiled at runtime after that): + +``` +cd text-generation-webui +git clone https://github.com/turboderp/exllama repositories/exllama +``` + +#### bitsandbytes on older NVIDIA GPUs + +bitsandbytes >= 0.39 may not work. In that case, to use `--load-in-8bit`, you may have to downgrade like this: * Linux: `pip install bitsandbytes==0.38.1` * Windows: `pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl` @@ -119,37 +129,48 @@ docker compose up --build ### Updating the requirements -From time to time, the `requirements.txt` changes. To update, use this command: +From time to time, the `requirements.txt` changes. To update, use these commands: ``` conda activate textgen cd text-generation-webui pip install -r requirements.txt --upgrade ``` + ## Downloading models -Models should be placed inside the `models/` folder. +Models should be placed in the `text-generation-webui/models` folder. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads). -[Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) is the main place to download models. These are some examples: +* Transformers or GPTQ models are made of several files and must be placed in a subfolder. Example: -* [Pythia](https://huggingface.co/models?sort=downloads&search=eleutherai%2Fpythia+deduped) -* [OPT](https://huggingface.co/models?search=facebook/opt) -* [GALACTICA](https://huggingface.co/models?search=facebook/galactica) -* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main) +``` +text-generation-webui +├── models +│   ├── lmsys_vicuna-33b-v1.3 +│   │   ├── config.json +│   │   ├── generation_config.json +│   │   ├── pytorch_model-00001-of-00007.bin +│   │   ├── pytorch_model-00002-of-00007.bin +│   │   ├── pytorch_model-00003-of-00007.bin +│   │   ├── pytorch_model-00004-of-00007.bin +│   │   ├── pytorch_model-00005-of-00007.bin +│   │   ├── pytorch_model-00006-of-00007.bin +│   │   ├── pytorch_model-00007-of-00007.bin +│   │   ├── pytorch_model.bin.index.json +│   │   ├── special_tokens_map.json +│   │   ├── tokenizer_config.json +│   │   └── tokenizer.model +``` -You can automatically download a model from HF using the script `download-model.py`: +* GGUF models are a single file and should be placed directly into `models`. Example: - python download-model.py organization/model +``` +text-generation-webui +├── models +│   ├── llama-2-13b-chat.Q4_K_M.gguf +``` -For example: - - python download-model.py facebook/opt-1.3b - -To download a protected model, set env vars `HF_USER` and `HF_PASS` to your Hugging Face username and password (or [User Access Token](https://huggingface.co/settings/tokens)). The model's terms must first be accepted on the HF website. - -#### GGML models - -You can drop these directly into the `models/` folder, making sure that the file name contains `ggml` somewhere and ends in `.bin`. +In both cases, you can use the "Model" tab of the UI to download the model from Hugging Face automatically. It is also possible to download via the command-line with `python download-model.py organization/model` (use `--help` to see all the options). #### GPT-4chan @@ -175,7 +196,10 @@ After downloading the model, follow these steps: python download-model.py EleutherAI/gpt-j-6B --text-only ``` -When you load this model in default or notebook modes, the "HTML" tab will show the generated text in 4chan format. +When you load this model in default or notebook modes, the "HTML" tab will show the generated text in 4chan format: + +![Image3](https://github.com/oobabooga/screenshots/raw/main/gpt4chan.png) + ## Starting the web UI @@ -195,8 +219,6 @@ Optionally, you can use the following command-line flags: | Flag | Description | |--------------------------------------------|-------------| | `-h`, `--help` | Show this help message and exit. | -| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | -| `--chat` | Launch the web UI in chat mode. | | `--multi-user` | Multi-user mode. Chat histories are not saved or automatically loaded. WARNING: this is highly experimental. | | `--character CHARACTER` | The name of the character to load in chat mode by default. | | `--model MODEL` | Name of the model to load by default. | @@ -204,16 +226,16 @@ Optionally, you can use the following command-line flags: | `--model-dir MODEL_DIR` | Path to directory with all the models. | | `--lora-dir LORA_DIR` | Path to directory with all the loras. | | `--model-menu` | Show a model menu in the terminal when the web UI is first launched. | -| `--no-stream` | Don't stream the text output in real time. | | `--settings SETTINGS_FILE` | Load the default interface settings from this yaml file. See `settings-template.yaml` for an example. If you create a file called `settings.yaml`, this file will be loaded by default without the need to use the `--settings` flag. | | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | | `--verbose` | Print the prompts to the terminal. | +| `--chat-buttons` | Show buttons on chat tab instead of hover menu. | #### Model loader | Flag | Description | |--------------------------------------------|-------------| -| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen | +| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, ctransformers | #### Accelerate/transformers @@ -243,18 +265,33 @@ Optionally, you can use the following command-line flags: | `--quant_type QUANT_TYPE` | quant_type for 4-bit. Valid options: nf4, fp4. | | `--use_double_quant` | use_double_quant for 4-bit. | -#### llama.cpp +#### GGUF (for llama.cpp and ctransformers) | Flag | Description | |-------------|-------------| | `--threads` | Number of threads to use. | | `--n_batch` | Maximum number of prompt tokens to batch together when calling llama_eval. | -| `--no-mmap` | Prevent mmap from being used. | -| `--mlock` | Force the system to keep the model in RAM. | -| `--cache-capacity CACHE_CAPACITY` | Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed. | | `--n-gpu-layers N_GPU_LAYERS` | Number of layers to offload to the GPU. Only works if llama-cpp-python was compiled with BLAS. Set this to 1000000000 to offload all layers to the GPU. | | `--n_ctx N_CTX` | Size of the prompt context. | -| `--llama_cpp_seed SEED` | Seed for llama-cpp models. Default 0 (random). | + +#### llama.cpp + +| Flag | Description | +|---------------|---------------| +| `--no-mmap` | Prevent mmap from being used. | +| `--mlock` | Force the system to keep the model in RAM. | +| `--mul_mat_q` | Activate new mulmat kernels. | +| `--cache-capacity CACHE_CAPACITY` | Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed. | +| `--tensor_split TENSOR_SPLIT` | Split the model across multiple GPUs, comma-separated list of proportions, e.g. 18,17 | +| `--llama_cpp_seed SEED` | Seed for llama-cpp models. Default 0 (random). | +| `--cpu` | Use the CPU version of llama-cpp-python instead of the GPU-accelerated version. | +|`--cfg-cache` | llamacpp_HF: Create an additional cache for CFG negative prompts. | + +#### ctransformers + +| Flag | Description | +|-------------|-------------| +| `--model_type MODEL_TYPE` | Model type of pre-quantized model. Currently gpt2, gptj, gptneox, falcon, llama, mpt, starcoder (gptbigcode), dollyv2, and replit are supported. | #### AutoGPTQ @@ -265,6 +302,7 @@ Optionally, you can use the following command-line flags: | `--no_inject_fused_mlp` | Triton mode only: disable the use of fused MLP, which will use less VRAM at the cost of slower inference. | | `--no_use_cuda_fp16` | This can make models faster on some systems. | | `--desc_act` | For models that don't have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig. | +| `--disable_exllama` | Disable ExLlama kernel, which can improve inference speed on some systems. | #### ExLlama @@ -272,8 +310,7 @@ Optionally, you can use the following command-line flags: |------------------|-------------| |`--gpu-split` | Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. `20,7,7` | |`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. | -|`--compress_pos_emb COMPRESS_POS_EMB` | Positional embeddings compression factor. Should typically be set to max_seq_len / 2048. | -|`--alpha_value ALPHA_VALUE` | Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both. ` +|`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. | #### GPTQ-for-LLaMa @@ -285,17 +322,6 @@ Optionally, you can use the following command-line flags: | `--pre_layer PRE_LAYER [PRE_LAYER ...]` | The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. For multi-gpu, write the numbers separated by spaces, eg `--pre_layer 30 60`. | | `--checkpoint CHECKPOINT` | The path to the quantized checkpoint file. If not specified, it will be automatically detected. | | `--monkey-patch` | Apply the monkey patch for using LoRAs with quantized models. -| `--quant_attn` | (triton) Enable quant attention. | -| `--warmup_autotune` | (triton) Enable warmup autotune. | -| `--fused_mlp` | (triton) Enable fused mlp. | - -#### FlexGen - -| Flag | Description | -|------------------|-------------| -| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). | -| `--compress-weight` | FlexGen: Whether to compress weight (default: False).| -| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). | #### DeepSpeed @@ -312,6 +338,14 @@ Optionally, you can use the following command-line flags: | `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". | | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | +#### RoPE (for llama.cpp, ExLlama, ExLlamaV2, and transformers) + +| Flag | Description | +|------------------|-------------| +| `--alpha_value ALPHA_VALUE` | Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both. | +| `--rope_freq_base ROPE_FREQ_BASE` | If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63). | +| `--compress_pos_emb COMPRESS_POS_EMB` | Positional embeddings compression factor. Should be set to (context length) / (model's original context length). Equal to 1/rope_freq_scale. | + #### Gradio | Flag | Description | @@ -323,6 +357,8 @@ Optionally, you can use the following command-line flags: | `--auto-launch` | Open the web UI in the default browser upon launch. | | `--gradio-auth USER:PWD` | set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3" | | `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" | +| `--ssl-keyfile SSL_KEYFILE` | The path to the SSL certificate key file. | +| `--ssl-certfile SSL_CERTFILE` | The path to the SSL certificate cert file. | #### API @@ -330,6 +366,7 @@ Optionally, you can use the following command-line flags: |---------------------------------------|-------------| | `--api` | Enable the API extension. | | `--public-api` | Create a public URL for the API using Cloudfare. | +| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. | | `--api-blocking-port BLOCKING_PORT` | The listening port for the blocking API. | | `--api-streaming-port STREAMING_PORT` | The listening port for the streaming API. | @@ -339,8 +376,6 @@ Optionally, you can use the following command-line flags: |---------------------------------------|-------------| | `--multimodal-pipeline PIPELINE` | The multimodal pipeline to use. Examples: `llava-7b`, `llava-13b`. | -Out of memory errors? [Check the low VRAM guide](docs/Low-VRAM-guide.md). - ## Presets Inference settings presets can be created under `presets/` as yaml files. These files are detected automatically at startup. @@ -349,18 +384,13 @@ The presets that are included by default are the result of a contest that receiv ## Contributing -* Pull requests, suggestions, and issue reports are welcome. -* Make sure to carefully [search](https://github.com/oobabooga/text-generation-webui/issues) existing issues before starting a new one. -* If you have some experience with git, testing an open pull request and leaving a comment on whether it works as expected or not is immensely helpful. -* A simple way to contribute, even if you are not a programmer, is to leave a 👍 on an issue or pull request that you find relevant. +If you would like to contribute to the project, check out the [Contributing guidelines](https://github.com/oobabooga/text-generation-webui/wiki/Contributing-guidelines). ## Community -* Subreddit: https://www.reddit.com/r/oobaboogazz/ +* Subreddit: https://www.reddit.com/r/oobabooga/ * Discord: https://discord.gg/jwZCF2dPQN -## Credits +## Acknowledgment -- Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui -- Godlike preset: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets -- Code for some of the sliders: https://github.com/PygmalionAI/gradio-ui/ +In August 2023, [Andreessen Horowitz](https://a16z.com/) (a16z) provided a generous grant to encourage and support my independent work on this project. I am **extremely** grateful for their trust and recognition, which will allow me to dedicate more time towards realizing the full potential of text-generation-webui. diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py index 8e37b56..bf4201c 100644 --- a/api-examples/api-example-chat-stream.py +++ b/api-examples/api-example-chat-stream.py @@ -1,4 +1,5 @@ import asyncio +import html import json import sys @@ -20,21 +21,28 @@ async def run(user_input, history): request = { 'user_input': user_input, 'max_new_tokens': 250, + 'auto_max_new_tokens': False, + 'max_tokens_second': 0, 'history': history, 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' 'character': 'Example', - 'instruction_template': 'Vicuna-v1.1', + 'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset 'your_name': 'You', - + # 'name1': 'name of user', # Optional + # 'name2': 'name of character', # Optional + # 'context': 'character context', # Optional + # 'greeting': 'greeting', # Optional + # 'name1_instruct': 'You', # Optional + # 'name2_instruct': 'Assistant', # Optional + # 'context_instruct': 'context_instruct', # Optional + # 'turn_template': 'turn_template', # Optional 'regenerate': False, '_continue': False, - 'stop_at_newline': False, - 'chat_generation_attempts': 1, - 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', + 'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', # Generation params. If 'preset' is set to different than 'None', the values # in presets/preset-name.yaml are used instead of the individual numbers. - 'preset': 'None', + 'preset': 'None', 'do_sample': True, 'temperature': 0.7, 'top_p': 0.1, @@ -55,11 +63,14 @@ async def run(user_input, history): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'guidance_scale': 1, + 'negative_prompt': '', 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } @@ -83,7 +94,7 @@ async def print_response_stream(user_input, history): async for new_history in run(user_input, history): cur_message = new_history['visible'][-1][1][cur_len:] cur_len += len(cur_message) - print(cur_message, end='') + print(html.unescape(cur_message), end='') sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py index 23f2f18..42ba0a6 100644 --- a/api-examples/api-example-chat.py +++ b/api-examples/api-example-chat.py @@ -1,3 +1,4 @@ +import html import json import requests @@ -14,17 +15,24 @@ def run(user_input, history): request = { 'user_input': user_input, 'max_new_tokens': 250, + 'auto_max_new_tokens': False, + 'max_tokens_second': 0, 'history': history, 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' 'character': 'Example', - 'instruction_template': 'Vicuna-v1.1', + 'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset 'your_name': 'You', - + # 'name1': 'name of user', # Optional + # 'name2': 'name of character', # Optional + # 'context': 'character context', # Optional + # 'greeting': 'greeting', # Optional + # 'name1_instruct': 'You', # Optional + # 'name2_instruct': 'Assistant', # Optional + # 'context_instruct': 'context_instruct', # Optional + # 'turn_template': 'turn_template', # Optional 'regenerate': False, '_continue': False, - 'stop_at_newline': False, - 'chat_generation_attempts': 1, - 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', + 'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', # Generation params. If 'preset' is set to different than 'None', the values # in presets/preset-name.yaml are used instead of the individual numbers. @@ -49,11 +57,14 @@ def run(user_input, history): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'guidance_scale': 1, + 'negative_prompt': '', 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } @@ -64,7 +75,7 @@ def run(user_input, history): result = response.json()['results'][0]['history'] print(json.dumps(result, indent=4)) print() - print(result['visible'][-1][1]) + print(html.unescape(result['visible'][-1][1])) if __name__ == '__main__': diff --git a/api-examples/api-example-model.py b/api-examples/api-example-model.py index 8e1e300..44109d3 100644 --- a/api-examples/api-example-model.py +++ b/api-examples/api-example-model.py @@ -4,8 +4,9 @@ import requests HOST = '0.0.0.0:5000' -def generate(prompt, tokens = 200): - request = { 'prompt': prompt, 'max_new_tokens': tokens } + +def generate(prompt, tokens=200): + request = {'prompt': prompt, 'max_new_tokens': tokens} response = requests.post(f'http://{HOST}/api/v1/generate', json=request) if response.status_code == 200: @@ -23,7 +24,7 @@ def print_basic_model_info(response): print("Model: ", response['result']['model_name']) print("Lora(s): ", response['result']['lora_names']) for setting in basic_settings: - print(setting, "=", response['result']['shared.settings'][setting]) + print(setting, "=", response['result']['shared.settings'][setting]) # model info @@ -54,7 +55,7 @@ def complex_model_load(model): 'action': 'load', 'model_name': model, 'args': { - 'gptq_for_llama': False, # Use AutoGPTQ by default, set to True for gptq-for-llama + 'loader': 'AutoGPTQ', 'bf16': False, 'load_in_8bit': False, @@ -74,18 +75,18 @@ def complex_model_load(model): 'rwkv_strategy': None, 'rwkv_cuda_on': False, - # b&b 4-bit - #'load_in_4bit': False, - #'compute_dtype': 'float16', - #'quant_type': 'nf4', - #'use_double_quant': False, + # b&b 4-bit + # 'load_in_4bit': False, + # 'compute_dtype': 'float16', + # 'quant_type': 'nf4', + # 'use_double_quant': False, - #"cpu": false, - #"auto_devices": false, - #"gpu_memory": null, - #"cpu_memory": null, - #"disk": false, - #"disk_cache_dir": "cache", + # "cpu": false, + # "auto_devices": false, + # "gpu_memory": null, + # "cpu_memory": null, + # "disk": false, + # "disk_cache_dir": "cache", }, } @@ -104,26 +105,25 @@ def complex_model_load(model): req['args']['load_in_8bit'] = True elif '-hf' in model or 'fp16' in model: if '7b' in model: - req['args']['bf16'] = True # for 24GB + req['args']['bf16'] = True # for 24GB elif '13b' in model: - req['args']['load_in_8bit'] = True # for 24GB - elif 'ggml' in model: - #req['args']['threads'] = 16 + req['args']['load_in_8bit'] = True # for 24GB + elif 'gguf' in model: + # req['args']['threads'] = 16 if '7b' in model: req['args']['n_gpu_layers'] = 100 elif '13b' in model: req['args']['n_gpu_layers'] = 100 elif '30b' in model or '33b' in model: - req['args']['n_gpu_layers'] = 59 # 24GB + req['args']['n_gpu_layers'] = 59 # 24GB elif '65b' in model: - req['args']['n_gpu_layers'] = 42 # 24GB + req['args']['n_gpu_layers'] = 42 # 24GB elif 'rwkv' in model: req['args']['rwkv_cuda_on'] = True if '14b' in model: - req['args']['rwkv_strategy'] = 'cuda f16i8' # 24GB + req['args']['rwkv_strategy'] = 'cuda f16i8' # 24GB else: - req['args']['rwkv_strategy'] = 'cuda f16' # 24GB - + req['args']['rwkv_strategy'] = 'cuda f16' # 24GB return model_api(req) @@ -134,7 +134,7 @@ if __name__ == '__main__': resp = complex_model_load(model) if 'error' in resp: - print (f"❌ {model} FAIL Error: {resp['error']['message']}") + print(f"❌ {model} FAIL Error: {resp['error']['message']}") continue else: print_basic_model_info(resp) @@ -142,17 +142,17 @@ if __name__ == '__main__': ans = generate("0,1,1,2,3,5,8,13,", tokens=2) if '21' in ans: - print (f"✅ {model} PASS ({ans})") + print(f"✅ {model} PASS ({ans})") else: - print (f"❌ {model} FAIL ({ans})") + print(f"❌ {model} FAIL ({ans})") except Exception as e: - print (f"❌ {model} FAIL Exception: {repr(e)}") - + print(f"❌ {model} FAIL Exception: {repr(e)}") + # 0,1,1,2,3,5,8,13, is the fibonacci sequence, the next number is 21. # Some results below. -""" $ ./model-api-example.py +""" $ ./model-api-example.py Model: 4bit_gpt4-x-alpaca-13b-native-4bit-128g-cuda Lora(s): [] truncation_length = 2048 diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py index 79a01e4..5382216 100644 --- a/api-examples/api-example-stream.py +++ b/api-examples/api-example-stream.py @@ -20,10 +20,12 @@ async def run(context): request = { 'prompt': context, 'max_new_tokens': 250, + 'auto_max_new_tokens': False, + 'max_tokens_second': 0, # Generation params. If 'preset' is set to different than 'None', the values # in presets/preset-name.yaml are used instead of the individual numbers. - 'preset': 'None', + 'preset': 'None', 'do_sample': True, 'temperature': 0.7, 'top_p': 0.1, @@ -44,11 +46,14 @@ async def run(context): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'guidance_scale': 1, + 'negative_prompt': '', 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } diff --git a/api-examples/api-example.py b/api-examples/api-example.py index b09823c..e6d79f9 100644 --- a/api-examples/api-example.py +++ b/api-examples/api-example.py @@ -12,10 +12,12 @@ def run(prompt): request = { 'prompt': prompt, 'max_new_tokens': 250, + 'auto_max_new_tokens': False, + 'max_tokens_second': 0, # Generation params. If 'preset' is set to different than 'None', the values # in presets/preset-name.yaml are used instead of the individual numbers. - 'preset': 'None', + 'preset': 'None', 'do_sample': True, 'temperature': 0.7, 'top_p': 0.1, @@ -36,11 +38,14 @@ def run(prompt): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'guidance_scale': 1, + 'negative_prompt': '', 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } diff --git a/characters/Assistant.yaml b/characters/Assistant.yaml new file mode 100644 index 0000000..a6141f4 --- /dev/null +++ b/characters/Assistant.yaml @@ -0,0 +1,4 @@ +name: AI +greeting: How can I help you today? +context: | + The following is a conversation with an AI Large Language Model. The AI has been trained to answer questions, provide recommendations, and help with decision making. The AI follows user requests. The AI thinks outside the box. diff --git a/characters/Example.yaml b/characters/Example.yaml index 0160f45..c1a3299 100644 --- a/characters/Example.yaml +++ b/characters/Example.yaml @@ -1,9 +1,10 @@ -name: "Chiharu Yamada" -context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology." +name: Chiharu Yamada greeting: |- *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started! -example_dialogue: |- +context: |- + Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology. + {{user}}: So how did you get into computer engineering? {{char}}: I've always loved tinkering with technology since I was a kid. {{user}}: That's really impressive! diff --git a/characters/instruction-following/WizardLM.yaml b/characters/instruction-following/WizardLM.yaml deleted file mode 100644 index c65bb8f..0000000 --- a/characters/instruction-following/WizardLM.yaml +++ /dev/null @@ -1,4 +0,0 @@ -user: "" -bot: "### Response:" -turn_template: "<|user-message|>\n\n<|bot|><|bot-message|>\n\n" -context: "" \ No newline at end of file diff --git a/convert-to-flexgen.py b/convert-to-flexgen.py deleted file mode 100644 index 7654593..0000000 --- a/convert-to-flexgen.py +++ /dev/null @@ -1,63 +0,0 @@ -''' - -Converts a transformers model to a format compatible with flexgen. - -''' - -import argparse -import os -from pathlib import Path - -import numpy as np -import torch -from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer - -parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) -parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") -args = parser.parse_args() - - -def disable_torch_init(): - """ - Disable the redundant torch default initialization to accelerate model creation. - """ - import torch - global torch_linear_init_backup - global torch_layer_norm_init_backup - - torch_linear_init_backup = torch.nn.Linear.reset_parameters - setattr(torch.nn.Linear, "reset_parameters", lambda self: None) - - torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters - setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) - - -def restore_torch_init(): - """Rollback the change made by disable_torch_init.""" - import torch - setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup) - setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup) - - -if __name__ == '__main__': - path = Path(args.MODEL) - model_name = path.name - - print(f"Loading {model_name}...") - # disable_torch_init() - model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True) - # restore_torch_init() - - tokenizer = AutoTokenizer.from_pretrained(path) - - out_folder = Path(f"models/{model_name}-np") - if not Path(out_folder).exists(): - os.mkdir(out_folder) - - print(f"Saving the converted model to {out_folder}...") - for name, param in tqdm(list(model.model.named_parameters())): - name = name.replace("decoder.final_layer_norm", "decoder.layer_norm") - param_path = os.path.join(out_folder, name) - with open(param_path, "wb") as f: - np.save(f, param.cpu().detach().numpy()) diff --git a/css/NotoSans/NotoSans-Black.woff b/css/NotoSans/NotoSans-Black.woff new file mode 100644 index 0000000..0280e0f Binary files /dev/null and b/css/NotoSans/NotoSans-Black.woff differ diff --git a/css/NotoSans/NotoSans-Black.woff2 b/css/NotoSans/NotoSans-Black.woff2 new file mode 100644 index 0000000..1d51183 Binary files /dev/null and b/css/NotoSans/NotoSans-Black.woff2 differ diff --git a/css/NotoSans/NotoSans-BlackItalic.woff b/css/NotoSans/NotoSans-BlackItalic.woff new file mode 100644 index 0000000..5cd4424 Binary files /dev/null and b/css/NotoSans/NotoSans-BlackItalic.woff differ diff --git a/css/NotoSans/NotoSans-BlackItalic.woff2 b/css/NotoSans/NotoSans-BlackItalic.woff2 new file mode 100644 index 0000000..f0baeca Binary files /dev/null and b/css/NotoSans/NotoSans-BlackItalic.woff2 differ diff --git a/css/NotoSans/NotoSans-Bold.woff b/css/NotoSans/NotoSans-Bold.woff new file mode 100644 index 0000000..750b737 Binary files /dev/null and b/css/NotoSans/NotoSans-Bold.woff differ diff --git a/css/NotoSans/NotoSans-Bold.woff2 b/css/NotoSans/NotoSans-Bold.woff2 new file mode 100644 index 0000000..af6c17d Binary files /dev/null and b/css/NotoSans/NotoSans-Bold.woff2 differ diff --git a/css/NotoSans/NotoSans-BoldItalic.woff b/css/NotoSans/NotoSans-BoldItalic.woff new file mode 100644 index 0000000..d484cf2 Binary files /dev/null and b/css/NotoSans/NotoSans-BoldItalic.woff differ diff --git a/css/NotoSans/NotoSans-BoldItalic.woff2 b/css/NotoSans/NotoSans-BoldItalic.woff2 new file mode 100644 index 0000000..210c3a1 Binary files /dev/null and b/css/NotoSans/NotoSans-BoldItalic.woff2 differ diff --git a/css/NotoSans/NotoSans-ExtraBold.woff b/css/NotoSans/NotoSans-ExtraBold.woff new file mode 100644 index 0000000..1a1e41d Binary files /dev/null and b/css/NotoSans/NotoSans-ExtraBold.woff differ diff --git a/css/NotoSans/NotoSans-ExtraBold.woff2 b/css/NotoSans/NotoSans-ExtraBold.woff2 new file mode 100644 index 0000000..e2bd323 Binary files /dev/null and b/css/NotoSans/NotoSans-ExtraBold.woff2 differ diff --git a/css/NotoSans/NotoSans-ExtraBoldItalic.woff b/css/NotoSans/NotoSans-ExtraBoldItalic.woff new file mode 100644 index 0000000..95d68a9 Binary files /dev/null and b/css/NotoSans/NotoSans-ExtraBoldItalic.woff differ diff --git a/css/NotoSans/NotoSans-ExtraBoldItalic.woff2 b/css/NotoSans/NotoSans-ExtraBoldItalic.woff2 new file mode 100644 index 0000000..65892ae Binary files /dev/null and b/css/NotoSans/NotoSans-ExtraBoldItalic.woff2 differ diff --git a/css/NotoSans/NotoSans-ExtraLight.woff b/css/NotoSans/NotoSans-ExtraLight.woff new file mode 100644 index 0000000..4b8a559 Binary files /dev/null and b/css/NotoSans/NotoSans-ExtraLight.woff differ diff --git a/css/NotoSans/NotoSans-ExtraLight.woff2 b/css/NotoSans/NotoSans-ExtraLight.woff2 new file mode 100644 index 0000000..e92cf55 Binary files /dev/null and b/css/NotoSans/NotoSans-ExtraLight.woff2 differ diff --git a/css/NotoSans/NotoSans-ExtraLightItalic.woff b/css/NotoSans/NotoSans-ExtraLightItalic.woff new file mode 100644 index 0000000..f0b0a67 Binary files /dev/null and b/css/NotoSans/NotoSans-ExtraLightItalic.woff differ diff --git a/css/NotoSans/NotoSans-ExtraLightItalic.woff2 b/css/NotoSans/NotoSans-ExtraLightItalic.woff2 new file mode 100644 index 0000000..d63c4f7 Binary files /dev/null and b/css/NotoSans/NotoSans-ExtraLightItalic.woff2 differ diff --git a/css/NotoSans/NotoSans-Italic.woff b/css/NotoSans/NotoSans-Italic.woff new file mode 100644 index 0000000..bc89297 Binary files /dev/null and b/css/NotoSans/NotoSans-Italic.woff differ diff --git a/css/NotoSans/NotoSans-Italic.woff2 b/css/NotoSans/NotoSans-Italic.woff2 new file mode 100644 index 0000000..a6bd8a3 Binary files /dev/null and b/css/NotoSans/NotoSans-Italic.woff2 differ diff --git a/css/NotoSans/NotoSans-Light.woff b/css/NotoSans/NotoSans-Light.woff new file mode 100644 index 0000000..b89c997 Binary files /dev/null and b/css/NotoSans/NotoSans-Light.woff differ diff --git a/css/NotoSans/NotoSans-Light.woff2 b/css/NotoSans/NotoSans-Light.woff2 new file mode 100644 index 0000000..962c6d7 Binary files /dev/null and b/css/NotoSans/NotoSans-Light.woff2 differ diff --git a/css/NotoSans/NotoSans-LightItalic.woff b/css/NotoSans/NotoSans-LightItalic.woff new file mode 100644 index 0000000..741ab91 Binary files /dev/null and b/css/NotoSans/NotoSans-LightItalic.woff differ diff --git a/css/NotoSans/NotoSans-LightItalic.woff2 b/css/NotoSans/NotoSans-LightItalic.woff2 new file mode 100644 index 0000000..9153283 Binary files /dev/null and b/css/NotoSans/NotoSans-LightItalic.woff2 differ diff --git a/css/NotoSans/NotoSans-Medium.woff b/css/NotoSans/NotoSans-Medium.woff new file mode 100644 index 0000000..d8dfb11 Binary files /dev/null and b/css/NotoSans/NotoSans-Medium.woff differ diff --git a/css/NotoSans/NotoSans-Medium.woff2 b/css/NotoSans/NotoSans-Medium.woff2 new file mode 100644 index 0000000..deff785 Binary files /dev/null and b/css/NotoSans/NotoSans-Medium.woff2 differ diff --git a/css/NotoSans/NotoSans-MediumItalic.woff b/css/NotoSans/NotoSans-MediumItalic.woff new file mode 100644 index 0000000..d7ca037 Binary files /dev/null and b/css/NotoSans/NotoSans-MediumItalic.woff differ diff --git a/css/NotoSans/NotoSans-MediumItalic.woff2 b/css/NotoSans/NotoSans-MediumItalic.woff2 new file mode 100644 index 0000000..d87d4b6 Binary files /dev/null and b/css/NotoSans/NotoSans-MediumItalic.woff2 differ diff --git a/css/NotoSans/NotoSans-Regular.woff b/css/NotoSans/NotoSans-Regular.woff new file mode 100644 index 0000000..64d9e17 Binary files /dev/null and b/css/NotoSans/NotoSans-Regular.woff differ diff --git a/css/NotoSans/NotoSans-Regular.woff2 b/css/NotoSans/NotoSans-Regular.woff2 new file mode 100644 index 0000000..172de3c Binary files /dev/null and b/css/NotoSans/NotoSans-Regular.woff2 differ diff --git a/css/NotoSans/NotoSans-SemiBold.woff b/css/NotoSans/NotoSans-SemiBold.woff new file mode 100644 index 0000000..abd6f54 Binary files /dev/null and b/css/NotoSans/NotoSans-SemiBold.woff differ diff --git a/css/NotoSans/NotoSans-SemiBold.woff2 b/css/NotoSans/NotoSans-SemiBold.woff2 new file mode 100644 index 0000000..1c38d67 Binary files /dev/null and b/css/NotoSans/NotoSans-SemiBold.woff2 differ diff --git a/css/NotoSans/NotoSans-SemiBoldItalic.woff b/css/NotoSans/NotoSans-SemiBoldItalic.woff new file mode 100644 index 0000000..32dd019 Binary files /dev/null and b/css/NotoSans/NotoSans-SemiBoldItalic.woff differ diff --git a/css/NotoSans/NotoSans-SemiBoldItalic.woff2 b/css/NotoSans/NotoSans-SemiBoldItalic.woff2 new file mode 100644 index 0000000..853adbf Binary files /dev/null and b/css/NotoSans/NotoSans-SemiBoldItalic.woff2 differ diff --git a/css/NotoSans/NotoSans-Thin.woff b/css/NotoSans/NotoSans-Thin.woff new file mode 100644 index 0000000..edb17d3 Binary files /dev/null and b/css/NotoSans/NotoSans-Thin.woff differ diff --git a/css/NotoSans/NotoSans-Thin.woff2 b/css/NotoSans/NotoSans-Thin.woff2 new file mode 100644 index 0000000..ca9fd92 Binary files /dev/null and b/css/NotoSans/NotoSans-Thin.woff2 differ diff --git a/css/NotoSans/NotoSans-ThinItalic.woff b/css/NotoSans/NotoSans-ThinItalic.woff new file mode 100644 index 0000000..43d33e8 Binary files /dev/null and b/css/NotoSans/NotoSans-ThinItalic.woff differ diff --git a/css/NotoSans/NotoSans-ThinItalic.woff2 b/css/NotoSans/NotoSans-ThinItalic.woff2 new file mode 100644 index 0000000..73d94ab Binary files /dev/null and b/css/NotoSans/NotoSans-ThinItalic.woff2 differ diff --git a/css/NotoSans/stylesheet.css b/css/NotoSans/stylesheet.css new file mode 100644 index 0000000..467973b --- /dev/null +++ b/css/NotoSans/stylesheet.css @@ -0,0 +1,166 @@ +/* +Copied from https://github.com/SillyTavern/SillyTavern/tree/6c8bd06308c69d51e2eb174541792a870a83d2d6/public/webfonts/NotoSans +*/ + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-Black.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-Black.woff') format('woff'); + font-weight: 900; + font-style: normal; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-ExtraBoldItalic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-ExtraBoldItalic.woff') format('woff'); + font-weight: bold; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-BlackItalic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-BlackItalic.woff') format('woff'); + font-weight: 900; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-ExtraBold.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-ExtraBold.woff') format('woff'); + font-weight: bold; + font-style: normal; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-ThinItalic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-ThinItalic.woff') format('woff'); + font-weight: 100; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-BoldItalic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-BoldItalic.woff') format('woff'); + font-weight: bold; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-Bold.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-Bold.woff') format('woff'); + font-weight: bold; + font-style: normal; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-LightItalic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-LightItalic.woff') format('woff'); + font-weight: 300; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-Italic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-Italic.woff') format('woff'); + font-weight: normal; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-ExtraLightItalic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-ExtraLightItalic.woff') format('woff'); + font-weight: 200; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-Light.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-Light.woff') format('woff'); + font-weight: 300; + font-style: normal; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-ExtraLight.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-ExtraLight.woff') format('woff'); + font-weight: 200; + font-style: normal; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-Medium.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-Medium.woff') format('woff'); + font-weight: 500; + font-style: normal; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-Regular.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-Regular.woff') format('woff'); + font-weight: normal; + font-style: normal; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-MediumItalic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-MediumItalic.woff') format('woff'); + font-weight: 500; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-SemiBoldItalic.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-SemiBoldItalic.woff') format('woff'); + font-weight: 600; + font-style: italic; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-SemiBold.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-SemiBold.woff') format('woff'); + font-weight: 600; + font-style: normal; + font-display: swap; +} + +@font-face { + font-family: 'Noto Sans'; + src: url('file/css/NotoSans/NotoSans-Thin.woff2') format('woff2'), + url('file/css/NotoSans/NotoSans-Thin.woff') format('woff'); + font-weight: 100; + font-style: normal; + font-display: swap; +} + diff --git a/css/chat.css b/css/chat.css deleted file mode 100644 index 45a518b..0000000 --- a/css/chat.css +++ /dev/null @@ -1,126 +0,0 @@ -.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx { - height: 66.67vh -} - -.gradio-container { - margin-left: auto !important; - margin-right: auto !important; -} - -.w-screen { - width: unset -} - -div.svelte-362y77>*, div.svelte-362y77>.form>* { - flex-wrap: nowrap -} - -/* fixes the API documentation in chat mode */ -.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h { - display: grid; -} - -.pending.svelte-1ed2p3z { - opacity: 1; -} - -#extensions { - padding: 0; - padding: 0; -} - -#gradio-chatbot { - height: 66.67vh; -} - -.wrap.svelte-6roggh.svelte-6roggh { - max-height: 92.5%; -} - -/* This is for the microphone button in the whisper extension */ -.sm.svelte-1ipelgc { - width: 100%; -} - -#main button { - min-width: 0 !important; -} - -/*****************************************************/ -/*************** Chat box declarations ***************/ -/*****************************************************/ - -.chat { - margin-left: auto; - margin-right: auto; - max-width: 800px; - height: calc(100vh - 296px); - overflow-y: auto; - padding-right: 20px; - display: flex; - flex-direction: column-reverse; - word-break: break-word; - overflow-wrap: anywhere; - padding-top: 1px; -} - -.message-body li { - margin-top: 0.5em !important; - margin-bottom: 0.5em !important; -} - -.message-body li > p { - display: inline !important; -} - -.message-body ul, .message-body ol { - font-size: 15px !important; -} - -.message-body ul { - list-style-type: disc !important; -} - -.message-body pre { - margin-bottom: 1.25em !important; -} - -.message-body code { - white-space: pre-wrap !important; - word-wrap: break-word !important; -} - -.message-body :not(pre) > code { - white-space: normal !important; -} - -@media print { - body { - visibility: hidden; - } - - .chat { - visibility: visible; - position: absolute; - left: 0; - top: 0; - max-width: none; - max-height: none; - width: 100%; - height: fit-content; - display: flex; - flex-direction: column-reverse; - } - - .message { - break-inside: avoid; - } - - .gradio-container { - overflow: visible; - } - - .tab-nav { - display: none !important; - } -} diff --git a/css/chat.js b/css/chat.js deleted file mode 100644 index e304f12..0000000 --- a/css/chat.js +++ /dev/null @@ -1,4 +0,0 @@ -document.getElementById("main").childNodes[0].style = "max-width: 800px; margin-left: auto; margin-right: auto"; -document.getElementById("extensions").style.setProperty("max-width", "800px"); -document.getElementById("extensions").style.setProperty("margin-left", "auto"); -document.getElementById("extensions").style.setProperty("margin-right", "auto"); diff --git a/css/chat_style-TheEncrypted777.css b/css/chat_style-TheEncrypted777.css index 7682011..dfc01eb 100644 --- a/css/chat_style-TheEncrypted777.css +++ b/css/chat_style-TheEncrypted777.css @@ -5,22 +5,14 @@ grid-template-columns: 60px minmax(0, 1fr); padding-bottom: 28px; font-size: 18px; - /*Change 'Quicksand' to a font you like or leave it*/ - font-family: Quicksand, Arial, sans-serif; + font-family: 'Noto Sans', Arial, sans-serif; line-height: 1.428571429; } -.circle-you { - background-color: gray; - border-radius: 1rem; - /*Change color to any you like to be the border of your image*/ - border: 2px solid white; -} - +.circle-you, .circle-bot { background-color: gray; border-radius: 1rem; - /*Change color to any you like to be the border of the bot's image*/ border: 2px solid white; } @@ -41,7 +33,7 @@ .text { /*Change this to move the message box further left or right depending on the size of your profile pic*/ padding-left: 90px; - text-shadow: 2px 2px 2px rgb(0, 0, 0); + text-shadow: 2px 2px 2px rgb(0, 0, 0, 0.4); } .text p { @@ -96,12 +88,46 @@ margin-bottom: 0 !important; font-size: 18px !important; line-height: 1.428571429 !important; -} - -.dark .message-body p em { - color: rgb(138, 138, 138) !important; + color: rgb(243, 244, 246) !important; + text-shadow: 2px 2px 2px rgb(0, 0, 0); } .message-body p em { - color: rgb(110, 110, 110) !important; + color: rgb(138, 138, 138) !important; +} + +@media screen and (max-width: 688px) { + .message { + display: grid; + grid-template-columns: 60px minmax(0, 1fr); + padding-bottom: 25px; + font-size: 15px; + font-family: 'Noto Sans', Helvetica, Arial, sans-serif; + line-height: 1.428571429; + } + + .circle-you, .circle-bot { + width: 50px; + height: 73px; + border-radius: 0.5rem; + } + + .circle-bot img, + .circle-you img { + width: 100%; + height: 100%; + object-fit: cover; + } + + .text { + padding-left: 0px; + } + + .message-body p { + font-size: 16px !important; + } + + .username { + font-size: 20px; + } } diff --git a/css/chat_style-cai-chat-square.css b/css/chat_style-cai-chat-square.css new file mode 100644 index 0000000..0098da3 --- /dev/null +++ b/css/chat_style-cai-chat-square.css @@ -0,0 +1,21 @@ +@import url("file/css/chat_style-cai-chat.css"); + +.circle-bot, .circle-you { + height: 90px; + width: 60px; + border-radius: 10px; + background-color: #656565; +} + +.circle-bot img, .circle-you img { + border-radius: 8.333px; +} + +.circle-you { + background-color: #656565; +} + +.message { + padding-bottom: 30px; + grid-template-columns: 70px minmax(0, 1fr); +} diff --git a/css/chat_style-cai-chat.css b/css/chat_style-cai-chat.css index d48fe76..47f39e0 100644 --- a/css/chat_style-cai-chat.css +++ b/css/chat_style-cai-chat.css @@ -3,8 +3,8 @@ grid-template-columns: 60px minmax(0, 1fr); padding-bottom: 25px; font-size: 15px; - font-family: Helvetica, Arial, sans-serif; - line-height: 1.428571429; + font-family: 'Noto Sans', Helvetica, Arial, sans-serif; + line-height: 23px !important; } .circle-you { @@ -46,7 +46,7 @@ .message-body p { margin-bottom: 0 !important; font-size: 15px !important; - line-height: 1.428571429 !important; + line-height: 23px !important; } .dark .message-body p em { @@ -55,4 +55,5 @@ .message-body p em { color: rgb(110, 110, 110) !important; + font-weight: 500; } \ No newline at end of file diff --git a/css/chat_style-messenger.css b/css/chat_style-messenger.css index 0e5528d..fb3f65a 100644 --- a/css/chat_style-messenger.css +++ b/css/chat_style-messenger.css @@ -1,7 +1,7 @@ .message { padding-bottom: 25px; font-size: 15px; - font-family: Helvetica, Arial, sans-serif; + font-family: 'Noto Sans', Helvetica, Arial, sans-serif; line-height: 1.428571429; } diff --git a/css/chat_style-wpp.css b/css/chat_style-wpp.css index 14b4087..da9f172 100644 --- a/css/chat_style-wpp.css +++ b/css/chat_style-wpp.css @@ -1,7 +1,7 @@ .message { padding-bottom: 25px; font-size: 15px; - font-family: Helvetica, Arial, sans-serif; + font-family: 'Noto Sans', Helvetica, Arial, sans-serif; line-height: 1.428571429; } diff --git a/css/html_4chan_style.css b/css/html_4chan_style.css index 99ac684..cef9f6e 100644 --- a/css/html_4chan_style.css +++ b/css/html_4chan_style.css @@ -98,7 +98,7 @@ margin-right: 40px !important; } -#parent #container .message { +#parent #container .message_4chan { color: black; border: none; } \ No newline at end of file diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css index 575281b..286029f 100644 --- a/css/html_instruct_style.css +++ b/css/html_instruct_style.css @@ -3,8 +3,8 @@ grid-template-columns: 60px 1fr; padding-bottom: 25px; font-size: 15px; - font-family: Helvetica, Arial, sans-serif; - line-height: 1.428571429; + font-family: 'Noto Sans', Helvetica, Arial, sans-serif; + line-height: 22px; } .username { @@ -13,11 +13,11 @@ .message-body p { font-size: 15px !important; - line-height: 1.75 !important; + line-height: 22px !important; margin-bottom: 1.25em !important; } -.message-body ul, .message-body ol { +.chat .message-body ul, .chat .message-body ol { margin-bottom: 1.25em !important; } @@ -43,14 +43,16 @@ margin-bottom: 9px !important; } +.gradio-container .chat .assistant-message:last-child, .gradio-container .chat .user-message:last-child { + margin-bottom: 0px !important; +} + .dark .chat .assistant-message { - background-color: #3741519e; - border: 1px solid #4b5563; + background-color: #1f2937; } .dark .chat .user-message { - background-color: #111827; - border: 1px solid #4b5563; + background-color: transparent; } code { @@ -58,5 +60,5 @@ code { } .dark code { - background-color: #1a212f !important; + background-color: #0e1321 !important; } \ No newline at end of file diff --git a/css/html_readable_style.css b/css/html_readable_style.css index cd5fca9..2cfa6f2 100644 --- a/css/html_readable_style.css +++ b/css/html_readable_style.css @@ -26,4 +26,8 @@ .container :not(pre) > code { white-space: normal !important; +} + +.container .hoverable { + font-size: 14px; } \ No newline at end of file diff --git a/css/main.css b/css/main.css index a00147c..da0e381 100644 --- a/css/main.css +++ b/css/main.css @@ -7,6 +7,7 @@ } .small-button { + min-width: 0 !important; max-width: 171px; height: 39.594px; align-self: end; @@ -26,6 +27,10 @@ max-width: 2.2em; } +.button_nowrap { + white-space: nowrap; +} + #slim-column { flex: none !important; min-width: 0 !important; @@ -41,9 +46,6 @@ min-height: 0 } -#accordion { -} - .dark svg { fill: white; } @@ -56,7 +58,7 @@ ol li p, ul li p { display: inline-block; } -#main, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab { +#chat-tab, #default-tab, #notebook-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab { border: 0; } @@ -70,7 +72,7 @@ ol li p, ul li p { } #extensions { - padding: 15px; + margin-top: 5px; margin-bottom: 35px; } @@ -89,7 +91,11 @@ div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { .header_bar { background-color: #f7f7f7; - margin-bottom: 30px; + margin-bottom: 19px; + display: inline !important; + overflow-x: scroll; + margin-left: calc(-1 * var(--size-4)); + margin-right: calc(-1 * var(--size-4)); } .dark .header_bar { @@ -97,19 +103,39 @@ div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { background-color: #8080802b; } +.header_bar button.selected { + border-radius: 0; +} + .textbox_default textarea { - height: calc(100vh - 390px); + height: calc(100dvh - 271px); } .textbox_default_output textarea { - height: calc(100vh - 200px); + height: calc(100dvh - 185px); } .textbox textarea { - height: calc(100vh - 251px); + height: calc(100dvh - 241px); } -.textbox_default textarea, .textbox_default_output textarea, .textbox textarea { +.textbox_logits textarea { + height: calc(100dvh - 236px); +} + +.textbox_logits_notebook textarea { + height: calc(100dvh - 292px); +} + +.monospace { + font-family: monospace; +} + +.textbox_default textarea, +.textbox_default_output textarea, +.textbox_logits textarea, +.textbox_logits_notebook textarea, +.textbox textarea { font-size: 16px !important; color: #46464A !important; } @@ -118,6 +144,24 @@ div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { color: #efefef !important; } +@media screen and (max-width: 711px) { + .textbox_default textarea { + height: calc(100dvh - 259px); + } + + div .default-token-counter { + top: calc( 0.5 * (100dvh - 236px) ) !important; + } + + .transparent-substring { + display: none; + } + + .hover-menu { + min-width: 250px !important; + } +} + /* Hide the gradio footer*/ footer { display: none !important; @@ -154,4 +198,406 @@ button { .markdown ul ol { font-size: 100% !important; -} \ No newline at end of file +} + +.pretty_scrollbar::-webkit-scrollbar { + width: 5px; +} + +.pretty_scrollbar::-webkit-scrollbar-track { + background: transparent; +} + +.pretty_scrollbar::-webkit-scrollbar-thumb, +.pretty_scrollbar::-webkit-scrollbar-thumb:hover { + background: #c5c5d2; +} + +.dark .pretty_scrollbar::-webkit-scrollbar-thumb, +.dark .pretty_scrollbar::-webkit-scrollbar-thumb:hover { + background: #374151; +} + +.pretty_scrollbar::-webkit-resizer { + background: #c5c5d2; +} + +.dark .pretty_scrollbar::-webkit-resizer { + background: #374151; +} + +audio { + max-width: 100%; +} + +/* Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui */ +.token-counter { + position: absolute !important; + top: calc( 0.5 * (100dvh - 218px) ) !important; + right: 2px; + z-index: 100; + background: var(--input-background-fill) !important; + min-height: 0 !important; +} + +.default-token-counter { + top: calc( 0.5 * (100dvh - 248px) ) !important; +} + +.token-counter span { + padding: 1px; + box-shadow: 0 0 0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075); + border: 2px solid rgba(192,192,192,0.4) !important; + border-radius: 0.4em; +} + +.no-background { + background: var(--background-fill-primary) !important; + padding: 0px !important; +} + +/*****************************************************/ +/*************** Chat UI declarations ****************/ +/*****************************************************/ + +.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx { + height: 66.67vh +} + +.gradio-container { + margin-left: auto !important; + margin-right: auto !important; +} + +.w-screen { + width: unset +} + +div.svelte-362y77>*, div.svelte-362y77>.form>* { + flex-wrap: nowrap +} + +.pending.svelte-1ed2p3z { + opacity: 1; +} + +.wrap.svelte-6roggh.svelte-6roggh { + max-height: 92.5%; +} + +/* This is for the microphone button in the whisper extension */ +.sm.svelte-1ipelgc { + width: 100%; +} + +#chat-tab button#Generate, #chat-tab button#stop { + width: 89.3438px !important; +} + +#chat-tab button, #notebook-tab button, #default-tab button { + min-width: 0 !important; +} + +#chat-tab > :first-child, #extensions { + max-width: 880px; + margin-left: auto; + margin-right: auto; +} + +@media screen and (max-width: 688px) { + #chat-tab { + padding-left: 0px; + padding-right: 0px; + } + + .chat-parent { + height: calc(100dvh - 179px) !important; + } + + .old-ui .chat-parent { + height: calc(100dvh - 310px) !important; + } +} + +.chat { + margin-left: auto; + margin-right: auto; + max-width: 880px; + height: 100%; + overflow-y: auto; + padding-right: 15px; + display: flex; + flex-direction: column; + word-break: break-word; + overflow-wrap: anywhere; +} + +.chat-parent { + height: calc(100dvh - 181px); + overflow: auto !important; +} + +.old-ui .chat-parent { + height: calc(100dvh - 270px); +} + +.chat-parent.bigchat { + height: calc(100dvh - 181px) !important; +} + +.chat > .messages { + display: flex; + flex-direction: column; +} + +.chat .message:last-child { + margin-bottom: 0px !important; + padding-bottom: 0px !important; +} + +.message-body li { + margin-top: 0 !important; + margin-bottom: 0 !important; +} + +.message-body li > p { + display: inline !important; +} + +.message-body ul, .message-body ol { + font-size: 15px !important; +} + +.message-body ul { + list-style-type: disc !important; +} + +.message-body pre { + margin-bottom: 1.25em !important; +} + +.message-body code { + white-space: pre-wrap !important; + word-wrap: break-word !important; +} + +.message-body :not(pre) > code { + white-space: normal !important; +} + +#chat-input { + padding: 0; + padding-top: 18px; + background: transparent; + border: none; +} + +#chat-input textarea:focus { + box-shadow: none !important; +} + +@media print { + body { + visibility: hidden; + } + + .chat { + visibility: visible; + position: absolute; + left: 0; + top: 0; + max-width: unset; + max-height: unset; + width: 100%; + overflow-y: visible; + } + + .message { + break-inside: avoid; + } + + .gradio-container { + overflow: visible; + } + + .tab-nav { + display: none !important; + } + + #chat-tab > :first-child { + max-width: unset; + } +} + +#show-controls { + position: absolute; + height: 100%; + background-color: var(--background-fill-primary); + border: 0px; + border-radius: 0px; +} + +#show-controls label { + z-index: 1000; + position: absolute; + left: calc(100% - 168px); +} + +#typing-container { + display: none; + position: absolute; + background-color: transparent; + left: -2px; + padding: var(--block-padding); +} + +.typing { + position: relative; +} + +.visible-dots #typing-container { + display: block; +} + +.typing span { + content: ''; + animation: blink 1.5s infinite; + animation-fill-mode: both; + height: 10px; + width: 10px; + background: #3b5998;; + position: absolute; + left:0; + top:0; + border-radius: 50%; +} + +.typing .dot1 { + animation-delay: .2s; + margin-left: calc(10px * 1.5); +} + +.typing .dot2 { + animation-delay: .4s; + margin-left: calc(10px * 3); +} + +@keyframes blink { + 0% { + opacity: .1; + } + 20% { + opacity: 1; + } + 100% { + opacity: .1; + } +} + +#chat-tab .generating { + display: none !important; +} + +.hover-element { + position: relative; + font-size: 24px; +} + +.hover-menu { + display: none; + position: absolute; + bottom: 80%; + left: 0; + background-color: var(--background-fill-secondary); + box-shadow: 0 0 10px rgba(0, 0, 0, 0.5); + z-index: 10000; + min-width: 330px; + flex-direction: column; +} + +.hover-menu button { + width: 100%; + background: transparent !important; + border-radius: 0px !important; + justify-content: space-between; + margin: 0 !important; + height: 36px; +} + +.hover-menu button:not(#clear-history-confirm) { + border-bottom: 0 !important; +} + +.hover-menu button:not(#clear-history-confirm):last-child { + border-bottom: var(--button-border-width) solid var(--button-secondary-border-color) !important; +} + +.hover-menu button:hover { + background: var(--button-secondary-background-fill-hover) !important; +} + +.transparent-substring { + opacity: 0.333; +} + +#chat-tab:not(.old-ui) #chat-buttons { + display: none !important; +} + +#gr-hover-container { + min-width: 0 !important; + display: flex; + flex-direction: column-reverse; + padding-right: 20px; + padding-bottom: 3px; + flex-grow: 0 !important; +} + +#generate-stop-container { + min-width: 0 !important; + display: flex; + flex-direction: column-reverse; + padding-bottom: 3px; + flex: 0 auto !important; +} + +#chat-input-container { + min-width: 0 !important; +} + +#chat-input-container > .form { + background: transparent; + border: none; +} + +#chat-input-row { + padding-bottom: 20px; +} + +.old-ui #chat-input-row, #chat-input-row.bigchat { + padding-bottom: 0px !important; +} + +#chat-col { + padding-bottom: 115px; +} + +.old-ui #chat-col, #chat-col.bigchat { + padding-bottom: 95px !important; +} + +.old-ui #chat-buttons #clear-history-confirm { + order: -1; +} + +.chat ol, .chat ul { + margin-top: 6px !important; +} + +#past-chats-row { + margin-bottom: calc( -1 * var(--layout-gap) ); +} + +#rename-row label { + margin-top: var(--layout-gap); +} diff --git a/css/main.js b/css/main.js deleted file mode 100644 index 32820eb..0000000 --- a/css/main.js +++ /dev/null @@ -1,18 +0,0 @@ -document.getElementById("main").parentNode.childNodes[0].classList.add("header_bar"); -document.getElementById("main").parentNode.style = "padding: 0; margin: 0"; -document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0"; - -// Get references to the elements -let main = document.getElementById('main'); -let main_parent = main.parentNode; -let extensions = document.getElementById('extensions'); - -// Add an event listener to the main element -main_parent.addEventListener('click', function(e) { - // Check if the main element is visible - if (main.offsetHeight > 0 && main.offsetWidth > 0) { - extensions.style.display = 'flex'; - } else { - extensions.style.display = 'none'; - } -}); diff --git a/docker/Dockerfile b/docker/Dockerfile index 7cc0ff1..810bb7c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,22 +1,23 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as builder -RUN apt-get update && \ +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw apt-get update && \ apt-get install --no-install-recommends -y git vim build-essential python3-dev python3-venv && \ rm -rf /var/lib/apt/lists/* -RUN git clone https://github.com/oobabooga/GPTQ-for-LLaMa /build +RUN git clone --depth=1 https://github.com/oobabooga/GPTQ-for-LLaMa /build WORKDIR /build -RUN python3 -m venv /build/venv -RUN . /build/venv/bin/activate && \ +RUN --mount=type=cache,target=/root/.cache/pip,rw \ + python3 -m venv /build/venv && \ + . /build/venv/bin/activate && \ pip3 install --upgrade pip setuptools wheel && \ pip3 install torch torchvision torchaudio && \ pip3 install -r requirements.txt # https://developer.nvidia.com/cuda-gpus # for a rtx 2060: ARG TORCH_CUDA_ARCH_LIST="7.5" -ARG TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX" +ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}" RUN . /build/venv/bin/activate && \ python3 setup_cuda.py bdist_wheel -d . @@ -25,11 +26,11 @@ FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04 LABEL maintainer="Your Name " LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI" -RUN apt-get update && \ - apt-get install --no-install-recommends -y python3-dev libportaudio2 libasound-dev git python3 python3-pip make g++ && \ +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw apt-get update && \ + apt-get install --no-install-recommends -y python3-dev libportaudio2 libasound-dev git python3 python3-pip make g++ ffmpeg && \ rm -rf /var/lib/apt/lists/* -RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv +RUN --mount=type=cache,target=/root/.cache/pip,rw pip3 install virtualenv RUN mkdir /app WORKDIR /app @@ -37,32 +38,38 @@ WORKDIR /app ARG WEBUI_VERSION RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Using provided webui source" +# Create virtualenv RUN virtualenv /app/venv -RUN . /app/venv/bin/activate && \ +RUN --mount=type=cache,target=/root/.cache/pip,rw \ + . /app/venv/bin/activate && \ pip3 install --upgrade pip setuptools wheel && \ - pip3 install torch torchvision torchaudio + pip3 install torch torchvision torchaudio sentence_transformers xformers +# Copy and install GPTQ-for-LLaMa COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa -RUN . /app/venv/bin/activate && \ +RUN --mount=type=cache,target=/root/.cache/pip,rw \ + . /app/venv/bin/activate && \ pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl -COPY extensions/api/requirements.txt /app/extensions/api/requirements.txt -COPY extensions/elevenlabs_tts/requirements.txt /app/extensions/elevenlabs_tts/requirements.txt -COPY extensions/google_translate/requirements.txt /app/extensions/google_translate/requirements.txt -COPY extensions/silero_tts/requirements.txt /app/extensions/silero_tts/requirements.txt -COPY extensions/whisper_stt/requirements.txt /app/extensions/whisper_stt/requirements.txt -RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt -RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt -RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt -RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt -RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt - +# Install main requirements COPY requirements.txt /app/requirements.txt -RUN . /app/venv/bin/activate && \ +RUN --mount=type=cache,target=/root/.cache/pip,rw \ + . /app/venv/bin/activate && \ pip3 install -r requirements.txt +COPY . /app/ + RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so -COPY . /app/ +# Install extension requirements +RUN --mount=type=cache,target=/root/.cache/pip,rw \ + . /app/venv/bin/activate && \ + for ext in /app/extensions/*/requirements.txt; do \ + cd "$(dirname "$ext")"; \ + pip3 install -r requirements.txt; \ + done + ENV CLI_ARGS="" + +EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005} CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS} diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index bc59dc3..ce29f33 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -5,13 +5,13 @@ services: context: . args: # specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus - TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST} - WEBUI_VERSION: ${WEBUI_VERSION} + TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5} + WEBUI_VERSION: ${WEBUI_VERSION:-HEAD} env_file: .env ports: - - "${HOST_PORT}:${CONTAINER_PORT}" - - "${HOST_API_PORT}:${CONTAINER_API_PORT}" - - "${HOST_API_STREAM_PORT}:${CONTAINER_API_STREAM_PORT}" + - "${HOST_PORT:-7860}:${CONTAINER_PORT:-7860}" + - "${HOST_API_PORT:-5000}:${CONTAINER_API_PORT:-5000}" + - "${HOST_API_STREAM_PORT:-5005}:${CONTAINER_API_STREAM_PORT:-5005}" stdin_open: true tty: true volumes: @@ -23,6 +23,7 @@ services: - ./prompts:/app/prompts - ./softprompts:/app/softprompts - ./training:/app/training + - ./cloudflared:/etc/cloudflared deploy: resources: reservations: diff --git a/docs/Chat-mode.md b/docs/Chat-mode.md index 08dd290..065e6a9 100644 --- a/docs/Chat-mode.md +++ b/docs/Chat-mode.md @@ -1,36 +1,30 @@ ## Chat characters -Custom chat mode characters are defined by `.yaml` files inside the `characters` folder. An example is included: [Example.yaml](https://github.com/oobabooga/text-generation-webui/blob/main/characters/Example.yaml) +Custom chat mode characters are defined by `.yaml` files inside the `characters` folder. An example is included: [Example.yaml](https://github.com/oobabooga/text-generation-webui/blob/main/characters/Example.yaml). The following fields may be defined: | Field | Description | |-------|-------------| | `name` or `bot` | The character's name. | +| `context` | A string that appears at the top of the prompt. It usually contains a description of the character's personality and a few example messages. | +| `greeting` (optional) | The character's opening message. It appears when the character is first loaded or when the history is cleared. | | `your_name` or `user` (optional) | Your name. This overwrites what you had previously written in the `Your name` field in the interface. | -| `context` | A string that appears at the top of the prompt. It usually contains a description of the character's personality. | -| `greeting` (optional) | The character's opening message when a new conversation is started. | -| `example_dialogue` (optional) | A few example messages to guide the model. | -| `turn_template` (optional) | Used to define where the spaces and new line characters should be in Instruct mode. See the characters in `characters/instruction-following` for examples. | #### Special tokens -* `{{char}}` or ``: are replaced with the character's name -* `{{user}}` or ``: are replaced with your name +The following replacements happen when the prompt is generated, and they apply to the `context` and `greeting` fields: -These replacements happen when the character is loaded, and they apply to the `context`, `greeting`, and `example_dialogue` fields. +* `{{char}}` and `` get replaced with the character's name. +* `{{user}}` and `` get replaced with your name. #### How do I add a profile picture for my character? -Put an image with the same name as your character's yaml file into the `characters` folder. For example, if your bot is `Character.yaml`, add `Character.jpg` or `Character.png` to the folder. +Put an image with the same name as your character's `.yaml` file into the `characters` folder. For example, if your bot is `Character.yaml`, add `Character.jpg` or `Character.png` to the folder. #### Is the chat history truncated in the prompt? -Once your prompt reaches the 2048 token limit, old messages will be removed one at a time. The context string will always stay at the top of the prompt and will never get truncated. - -#### Pygmalion format characters - -These are also supported out of the box. Simply put the JSON file in the `characters` folder, or upload it directly from the web UI by clicking on the "Upload character" tab at the bottom. +Once your prompt reaches the `truncation_length` parameter (2048 by default), old messages will be removed one at a time. The context string will always stay at the top of the prompt and will never get truncated. ## Chat styles diff --git a/docs/Extensions.md b/docs/Extensions.md index e156456..53acce5 100644 --- a/docs/Extensions.md +++ b/docs/Extensions.md @@ -1,45 +1,47 @@ -Extensions are defined by files named `script.py` inside subfolders of `text-generation-webui/extensions`. They are loaded at startup if specified with the `--extensions` flag. +# Extensions + +Extensions are defined by files named `script.py` inside subfolders of `text-generation-webui/extensions`. They are loaded at startup if the folder name is specified after the `--extensions` flag. For instance, `extensions/silero_tts/script.py` gets loaded with `python server.py --extensions silero_tts`. ## [text-generation-webui-extensions](https://github.com/oobabooga/text-generation-webui-extensions) -The link above contains a directory of user extensions for text-generation-webui. +The repository above contains a directory of user extensions. -If you create an extension, you are welcome to host it in a GitHub repository and submit it to the list above. +If you create an extension, you are welcome to host it in a GitHub repository and submit a PR adding it to the list. ## Built-in extensions -Most of these have been created by the extremely talented contributors that you can find here: [contributors](https://github.com/oobabooga/text-generation-webui/graphs/contributors?from=2022-12-18&to=&type=a). - |Extension|Description| |---------|-----------| -|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` port 5000. This is the main API for this web UI. | +|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` port 5000. This is the main API for the webui. | +|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. | +|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. | |[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.| -|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that biases the bot's responses in chat mode.| -|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. | -|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, it replaces the responses with an audio widget. | +|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, responses are replaced with an audio widget. | |[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. | -|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. | |[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. | |[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). | -|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. | -|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. | +|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that adds a hidden string at the beginning of the bot's reply in chat mode. | +|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. | +|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. | |[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. | +|[ngrok](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/ngrok)| Allows you to access the web UI remotely using the ngrok reverse tunnel service (free). It's an alternative to the built-in Gradio `--share` feature. | +|[perplexity_colors](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/perplexity_colors)| Colors each token in the output text by its associated probability, as derived from the model logits. | ## How to write an extension -script.py may define the special functions and variables below. - -#### Predefined functions +The extensions framework is based on special functions and variables that you can define in `script.py`. The functions are the following: | Function | Description | |-------------|-------------| +| `def setup()` | Is executed when the extension gets imported. | | `def ui()` | Creates custom gradio elements when the UI is launched. | | `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. | | `def custom_js()` | Same as above but for javascript. | -| `def input_modifier(string, state)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | -| `def output_modifier(string, state)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | +| `def input_modifier(string, state, is_chat=False)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | +| `def output_modifier(string, state, is_chat=False)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | +| `def chat_input_modifier(text, visible_text, state)` | Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. | | `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | @@ -48,9 +50,7 @@ script.py may define the special functions and variables below. | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. | | `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `multimodal` extension for an example. | -#### `params` dictionary - -In this dictionary, `display_name` is used to define the displayed name of the extension in the UI, and `is_tab` is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab. +Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab. Example: @@ -61,7 +61,7 @@ params = { } ``` -Additionally, `params` may contain variables that you want to be customizable through a `settings.json` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` in +The `params` dict may also contain variables that you want to be customizable through a `settings.yaml` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` in ```python params = { @@ -71,32 +71,19 @@ params = { } ``` -can be customized by adding a key called `google_translate-language string` to `settings.json`: +can be customized by adding a key called `google_translate-language string` to `settings.yaml`: ```python -"google_translate-language string": "fr", +google_translate-language string: 'fr' ``` -That is, the syntax is `extension_name-variable_name`. - -#### `input_hijack` dictionary - -```python -input_hijack = { - 'state': False, - 'value': ["", ""] -} -``` -This is only used in chat mode. If your extension sets `input_hijack['state'] = True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example. - -Additionally, your extension can set the value to be a callback in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `multimodal` extension above for an example. +That is, the syntax for the key is `extension_name-variable_name`. ## Using multiple extensions at the same time -In order to use your extension, you must start the web UI with the `--extensions` flag followed by the name of your extension (the folder under `text-generation-webui/extension` where `script.py` resides). - -You can activate more than one extension at a time by providing their names separated by spaces. The input, output, and bot prefix modifiers will be applied in the specified order. +You can activate more than one extension at a time by providing their names separated by spaces after `--extensions`. The input, output, and bot prefix modifiers will be applied in the specified order. +Example: ``` python server.py --extensions enthusiasm translate # First apply enthusiasm, then translate @@ -106,56 +93,152 @@ python server.py --extensions translate enthusiasm # First apply translate, then Do note, that for: - `custom_generate_chat_prompt` - `custom_generate_reply` -- `tokenizer_modifier` - `custom_tokenized_length` only the first declaration encountered will be used and the rest will be ignored. -## The `bot_prefix_modifier` +## A full example -In chat mode, this function modifies the prefix for a new bot message. For instance, if your bot is named `Marie Antoinette`, the default prefix for a new message will be - -``` -Marie Antoinette: -``` - -Using `bot_prefix_modifier`, you can change it to: - -``` -Marie Antoinette: *I am very enthusiastic* -``` - -Marie Antoinette will become very enthusiastic in all her messages. - -## `custom_generate_reply` example - -Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet. - -Note that in chat mode, this function must only return the new text, whereas in other modes it must return the original prompt + the new text. +The source code below can be found at [extensions/example/script.py](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/example/script.py). ```python -import datetime +""" +An example of extension. It does nothing, but you can add transformations +before the return statements to customize the webui behavior. -def custom_generate_reply(question, original_question, seed, state, stopping_strings): - cumulative = '' - for i in range(10): - cumulative += f"Counting: {i}...\n" - yield cumulative +Starting from history_modifier and ending in output_modifier, the +functions are declared in the same order that they are called at +generation time. +""" - cumulative += f"Done! {str(datetime.datetime.now())}" - yield cumulative -``` +import gradio as gr +import torch +from transformers import LogitsProcessor -## `custom_generate_chat_prompt` example +from modules import chat, shared +from modules.text_generation import ( + decode, + encode, + generate_reply, +) -Below is an extension that just reproduces the default prompt generator in `modules/chat.py`. You can modify it freely to come up with your own prompts in chat mode. +params = { + "display_name": "Example Extension", + "is_tab": False, +} -```python -from modules import chat +class MyLogits(LogitsProcessor): + """ + Manipulates the probabilities for the next token before it gets sampled. + Used in the logits_processor_modifier function below. + """ + def __init__(self): + pass + + def __call__(self, input_ids, scores): + # probs = torch.softmax(scores, dim=-1, dtype=torch.float) + # probs[0] /= probs[0].sum() + # scores = torch.log(probs / (1 - probs)) + return scores + +def history_modifier(history): + """ + Modifies the chat history. + Only used in chat mode. + """ + return history + +def state_modifier(state): + """ + Modifies the state variable, which is a dictionary containing the input + values in the UI like sliders and checkboxes. + """ + return state + +def chat_input_modifier(text, visible_text, state): + """ + Modifies the user input string in chat mode (visible_text). + You can also modify the internal representation of the user + input (text) to change how it will appear in the prompt. + """ + return text, visible_text + +def input_modifier(string, state, is_chat=False): + """ + In default/notebook modes, modifies the whole prompt. + + In chat mode, it is the same as chat_input_modifier but only applied + to "text", here called "string", and not to "visible_text". + """ + return string + +def bot_prefix_modifier(string, state): + """ + Modifies the prefix for the next bot reply in chat mode. + By default, the prefix will be something like "Bot Name:". + """ + return string + +def tokenizer_modifier(state, prompt, input_ids, input_embeds): + """ + Modifies the input ids and embeds. + Used by the multimodal extension to put image embeddings in the prompt. + Only used by loaders that use the transformers library for sampling. + """ + return prompt, input_ids, input_embeds + +def logits_processor_modifier(processor_list, input_ids): + """ + Adds logits processors to the list, allowing you to access and modify + the next token probabilities. + Only used by loaders that use the transformers library for sampling. + """ + processor_list.append(MyLogits()) + return processor_list + +def output_modifier(string, state, is_chat=False): + """ + Modifies the LLM output before it gets presented. + + In chat mode, the modified version goes into history['visible'], + and the original version goes into history['internal']. + """ + return string def custom_generate_chat_prompt(user_input, state, **kwargs): - - # Do something with kwargs['history'] or state + """ + Replaces the function that generates the prompt from the chat history. + Only used in chat mode. + """ + result = chat.generate_chat_prompt(user_input, state, **kwargs) + return result - return chat.generate_chat_prompt(user_input, state, **kwargs) +def custom_css(): + """ + Returns a CSS string that gets appended to the CSS for the webui. + """ + return '' + +def custom_js(): + """ + Returns a javascript string that gets appended to the javascript + for the webui. + """ + return '' + +def setup(): + """ + Gets executed only once, when the extension is imported. + """ + pass + +def ui(): + """ + Gets executed when the UI is drawn. Custom gradio elements and + their corresponding event handlers should be defined here. + + To learn about gradio components, check out the docs: + https://gradio.app/docs/ + """ + pass ``` diff --git a/docs/FlexGen.md b/docs/FlexGen.md deleted file mode 100644 index 931cc36..0000000 --- a/docs/FlexGen.md +++ /dev/null @@ -1,64 +0,0 @@ ->FlexGen is a high-throughput generation engine for running large language models with limited GPU memory (e.g., a 16GB T4 GPU or a 24GB RTX3090 gaming card!). - -https://github.com/FMInference/FlexGen - -## Installation - -No additional installation steps are necessary. FlexGen is in the `requirements.txt` file for this project. - -## Converting a model - -FlexGen only works with the OPT model, and it needs to be converted to numpy format before starting the web UI: - -``` -python convert-to-flexgen.py models/opt-1.3b/ -``` - -The output will be saved to `models/opt-1.3b-np/`. - -## Usage - -The basic command is the following: - -``` -python server.py --model opt-1.3b --loader flexgen -``` - -For large models, the RAM usage may be too high and your computer may freeze. If that happens, you can try this: - -``` -python server.py --model opt-1.3b --loader flexgen --compress-weight -``` - -With this second command, I was able to run both OPT-6.7b and OPT-13B with **2GB VRAM**, and the speed was good in both cases. - -You can also manually set the offload strategy with - -``` -python server.py --model opt-1.3b --loader flexgen --percent 0 100 100 0 100 0 -``` - -where the six numbers after `--percent` are: - -``` -the percentage of weight on GPU -the percentage of weight on CPU -the percentage of attention cache on GPU -the percentage of attention cache on CPU -the percentage of activations on GPU -the percentage of activations on CPU -``` - -You should typically only change the first two numbers. If their sum is less than 100, the remaining layers will be offloaded to the disk, by default into the `text-generation-webui/cache` folder. - -## Performance - -In my experiments with OPT-30B using a RTX 3090 on Linux, I have obtained these results: - -* `--loader flexgen --compress-weight --percent 0 100 100 0 100 0`: 0.99 seconds per token. -* `--loader flexgen --compress-weight --percent 100 0 100 0 100 0`: 0.765 seconds per token. - -## Limitations - -* Only works with the OPT models. -* Only two generation parameters are available: `temperature` and `do_sample`. \ No newline at end of file diff --git a/docs/GPTQ-models-(4-bit-mode).md b/docs/GPTQ-models-(4-bit-mode).md index 63a6ed5..730e832 100644 --- a/docs/GPTQ-models-(4-bit-mode).md +++ b/docs/GPTQ-models-(4-bit-mode).md @@ -64,59 +64,19 @@ python server.py --autogptq --gpu-memory 3000MiB 6000MiB --model model_name ### Using LoRAs with AutoGPTQ -Not supported yet. +Works fine for a single LoRA. ## GPTQ-for-LLaMa GPTQ-for-LLaMa is the original adaptation of GPTQ for the LLaMA model. It was made possible by [@qwopqwop200](https://github.com/qwopqwop200/GPTQ-for-LLaMa): https://github.com/qwopqwop200/GPTQ-for-LLaMa -Different branches of GPTQ-for-LLaMa are currently available, including: - -| Branch | Comment | -|----|----| -| [Old CUDA branch (recommended)](https://github.com/oobabooga/GPTQ-for-LLaMa/) | The fastest branch, works on Windows and Linux. | -| [Up-to-date triton branch](https://github.com/qwopqwop200/GPTQ-for-LLaMa) | Slightly more precise than the old CUDA branch from 13b upwards, significantly more precise for 7b. 2x slower for small context size and only works on Linux. | -| [Up-to-date CUDA branch](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda) | As precise as the up-to-date triton branch, 10x slower than the old cuda branch for small context size. | - -Overall, I recommend using the old CUDA branch. It is included by default in the one-click-installer for this web UI. - -### Installation - -Start by cloning GPTQ-for-LLaMa into your `text-generation-webui/repositories` folder: - -``` -mkdir repositories -cd repositories -git clone https://github.com/oobabooga/GPTQ-for-LLaMa.git -b cuda -``` - -If you want to you to use the up-to-date CUDA or triton branches instead of the old CUDA branch, use these commands: - -``` -git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git -b cuda -``` - -``` -git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git -b triton -``` - -Next you need to install the CUDA extensions. You can do that either by installing the precompiled wheels, or by compiling the wheels yourself. +A Python package containing both major CUDA versions of GPTQ-for-LLaMa is used to simplify installation and compatibility: https://github.com/jllllll/GPTQ-for-LLaMa-CUDA ### Precompiled wheels -Kindly provided by our friend jllllll: https://github.com/jllllll/GPTQ-for-LLaMa-Wheels +Kindly provided by our friend jllllll: https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases -Windows: - -``` -pip install https://github.com/jllllll/GPTQ-for-LLaMa-Wheels/raw/main/quant_cuda-0.0.0-cp310-cp310-win_amd64.whl -``` - -Linux: - -``` -pip install https://github.com/jllllll/GPTQ-for-LLaMa-Wheels/raw/Linux-x64/quant_cuda-0.0.0-cp310-cp310-linux_x86_64.whl -``` +Wheels are included in requirements.txt and are installed with the webui on supported systems. ### Manual installation @@ -124,30 +84,42 @@ pip install https://github.com/jllllll/GPTQ-for-LLaMa-Wheels/raw/Linux-x64/quant ``` conda activate textgen -conda install -c conda-forge cudatoolkit-dev +conda install cuda -c nvidia/label/cuda-11.7.1 ``` The command above takes some 10 minutes to run and shows no progress bar or updates along the way. -You are also going to need to have a C++ compiler installed. On Linux, `sudo apt install build-essential` or equivalent is enough. +You are also going to need to have a C++ compiler installed. On Linux, `sudo apt install build-essential` or equivalent is enough. On Windows, Visual Studio or Visual Studio Build Tools is required. -If you're using an older version of CUDA toolkit (e.g. 11.7) but the latest version of `gcc` and `g++` (12.0+), you should downgrade with: `conda install -c conda-forge gxx==11.3.0`. Kernel compilation will fail otherwise. +If you're using an older version of CUDA toolkit (e.g. 11.7) but the latest version of `gcc` and `g++` (12.0+) on Linux, you should downgrade with: `conda install -c conda-forge gxx==11.3.0`. Kernel compilation will fail otherwise. #### Step 2: compile the CUDA extensions ``` -cd repositories/GPTQ-for-LLaMa -python setup_cuda.py install +python -m pip install git+https://github.com/jllllll/GPTQ-for-LLaMa-CUDA -v ``` ### Getting pre-converted LLaMA weights -These are models that you can simply download and place in your `models` folder. +* Direct download (recommended): -* Converted without `group-size` (better for the 7b model): https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483891617 -* Converted with `group-size` (better from 13b upwards): https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483941105 +https://huggingface.co/Neko-Institute-of-Science/LLaMA-7B-4bit-128g -⚠️ The tokenizer files in the sources above may be outdated. Make sure to obtain the universal LLaMA tokenizer as described [here](https://github.com/oobabooga/text-generation-webui/blob/main/docs/LLaMA-model.md#option-1-pre-converted-weights). +https://huggingface.co/Neko-Institute-of-Science/LLaMA-13B-4bit-128g + +https://huggingface.co/Neko-Institute-of-Science/LLaMA-30B-4bit-128g + +https://huggingface.co/Neko-Institute-of-Science/LLaMA-65B-4bit-128g + +These models were converted with `desc_act=True`. They work just fine with ExLlama. For AutoGPTQ, they will only work on Linux with the `triton` option checked. + +* Torrent: + +https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483891617 + +https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483941105 + +These models were converted with `desc_act=False`. As such, they are less accurate, but they work with AutoGPTQ on Windows. The `128g` versions are better from 13b upwards, and worse for 7b. The tokenizer files in the torrents are outdated, in particular the files called `tokenizer_config.json` and `special_tokens_map.json`. Here you can find those files: https://huggingface.co/oobabooga/llama-tokenizer ### Starting the web UI: @@ -191,22 +163,17 @@ This requires using a monkey patch that is supported by this web UI: https://git To use it: -1. Clone `johnsmith0031/alpaca_lora_4bit` into the repositories folder: +1. Install alpaca_lora_4bit using pip ``` -cd text-generation-webui/repositories -git clone https://github.com/johnsmith0031/alpaca_lora_4bit +git clone https://github.com/johnsmith0031/alpaca_lora_4bit.git +cd alpaca_lora_4bit +git fetch origin winglian-setup_pip +git checkout winglian-setup_pip +pip install . ``` -⚠️ I have tested it with the following commit specifically: `2f704b93c961bf202937b10aac9322b092afdce0` - -2. Install https://github.com/sterlind/GPTQ-for-LLaMa with this command: - -``` -pip install git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit -``` - -3. Start the UI with the `--monkey-patch` flag: +2. Start the UI with the `--monkey-patch` flag: ``` python server.py --model llama-7b-4bit-128g --listen --lora tloen_alpaca-lora-7b --monkey-patch diff --git a/docs/Generation-parameters.md b/docs/Generation-parameters.md deleted file mode 100644 index 4477421..0000000 --- a/docs/Generation-parameters.md +++ /dev/null @@ -1,35 +0,0 @@ -# Generation parameters - -For a description of the generation parameters provided by the transformers library, see this link: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig - -### llama.cpp - -llama.cpp only uses the following parameters: - -* temperature -* top_p -* top_k -* repetition_penalty -* tfs -* mirostat_mode -* mirostat_tau -* mirostat_eta - -### ExLlama - -ExLlama only uses the following parameters: - -* temperature -* top_p -* top_k -* repetition_penalty -* repetition_penalty_range -* typical_p - -### RWKV - -RWKV only uses the following parameters when loaded through the old .pth weights: - -* temperature -* top_p -* top_k diff --git a/docs/LLaMA-model.md b/docs/LLaMA-model.md index cd65526..ba7350f 100644 --- a/docs/LLaMA-model.md +++ b/docs/LLaMA-model.md @@ -9,10 +9,21 @@ This guide will cover usage through the official `transformers` implementation. ### Option 1: pre-converted weights -* Torrent: https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789 -* Direct download: https://huggingface.co/Neko-Institute-of-Science +* Direct download (recommended): -⚠️ The tokenizers for the Torrent source above and also for many LLaMA fine-tunes available on Hugging Face may be outdated, in particular the files called `tokenizer_config.json` and `special_tokens_map.json`. Here you can find those files: https://huggingface.co/oobabooga/llama-tokenizer +https://huggingface.co/Neko-Institute-of-Science/LLaMA-7B-HF + +https://huggingface.co/Neko-Institute-of-Science/LLaMA-13B-HF + +https://huggingface.co/Neko-Institute-of-Science/LLaMA-30B-HF + +https://huggingface.co/Neko-Institute-of-Science/LLaMA-65B-HF + +* Torrent: + +https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789 + +The tokenizer files in the torrent above are outdated, in particular the files called `tokenizer_config.json` and `special_tokens_map.json`. Here you can find those files: https://huggingface.co/oobabooga/llama-tokenizer ### Option 2: convert the weights yourself diff --git a/docs/LLaMA-v2-model.md b/docs/LLaMA-v2-model.md new file mode 100644 index 0000000..55c6aa7 --- /dev/null +++ b/docs/LLaMA-v2-model.md @@ -0,0 +1,35 @@ +# LLaMA-v2 + +To convert LLaMA-v2 from the `.pth` format provided by Meta to transformers format, follow the steps below: + +1) `cd` into your `llama` folder (the one containing `download.sh` and the models that you downloaded): + +``` +cd llama +``` + +2) Clone the transformers library: + +``` +git clone 'https://github.com/huggingface/transformers' + +``` + +3) Create symbolic links from the downloaded folders to names that the conversion script can recognize: + +``` +ln -s llama-2-7b 7B +ln -s llama-2-13b 13B +``` + +4) Do the conversions: + +``` +mkdir llama-2-7b-hf llama-2-13b-hf +python ./transformers/src/transformers/models/llama/convert_llama_weights_to_hf.py --input_dir . --model_size 7B --output_dir llama-2-7b-hf --safe_serialization true +python ./transformers/src/transformers/models/llama/convert_llama_weights_to_hf.py --input_dir . --model_size 13B --output_dir llama-2-13b-hf --safe_serialization true +``` + +5) Move the output folders inside `text-generation-webui/models` + +6) Have fun diff --git a/docs/README.md b/docs/README.md index 06b73b8..6ab8d21 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,11 +8,9 @@ * [Docker](Docker.md) * [ExLlama](ExLlama.md) * [Extensions](Extensions.md) -* [FlexGen](FlexGen.md) -* [Generation parameters](Generation-parameters.md) * [GPTQ models (4 bit mode)](GPTQ-models-(4-bit-mode).md) -* [llama.cpp models](llama.cpp-models.md) * [LLaMA model](LLaMA-model.md) +* [llama.cpp](llama.cpp.md) * [LoRA](LoRA.md) * [Low VRAM guide](Low-VRAM-guide.md) * [RWKV model](RWKV-model.md) diff --git a/docs/llama.cpp-models.md b/docs/llama.cpp-models.md deleted file mode 100644 index bcf3c04..0000000 --- a/docs/llama.cpp-models.md +++ /dev/null @@ -1,53 +0,0 @@ -# Using llama.cpp in the web UI - -## Setting up the models - -#### Pre-converted - -Place the model in the `models` folder, making sure that its name contains `ggml` somewhere and ends in `.bin`. - -#### Convert LLaMA yourself - -Follow the instructions in the llama.cpp README to generate the `ggml-model.bin` file: https://github.com/ggerganov/llama.cpp#usage - -## GPU acceleration - -Enabled with the `--n-gpu-layers` parameter. - -* If you have enough VRAM, use a high number like `--n-gpu-layers 200000` to offload all layers to the GPU. -* Otherwise, start with a low number like `--n-gpu-layers 10` and then gradually increase it until you run out of memory. - -To use this feature, you need to manually compile and install `llama-cpp-python` with GPU support. - -#### Linux - -``` -pip uninstall -y llama-cpp-python -CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir -``` - -#### Windows - -``` -pip uninstall -y llama-cpp-python -set CMAKE_ARGS="-DLLAMA_CUBLAS=on" -set FORCE_CMAKE=1 -pip install llama-cpp-python --no-cache-dir -``` - -#### macOS - -``` -pip uninstall -y llama-cpp-python -CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir -``` - -Here you can find the different compilation options for OpenBLAS / cuBLAS / CLBlast: https://pypi.org/project/llama-cpp-python/ - -## Performance - -This was the performance of llama-7b int4 on my i5-12400F (cpu only): - -> Output generated in 33.07 seconds (6.05 tokens/s, 200 tokens, context 17) - -You can change the number of threads with `--threads N`. diff --git a/docs/llama.cpp.md b/docs/llama.cpp.md new file mode 100644 index 0000000..48d60df --- /dev/null +++ b/docs/llama.cpp.md @@ -0,0 +1,43 @@ +# llama.cpp + +llama.cpp is the best backend in two important scenarios: + +1) You don't have a GPU. +2) You want to run a model that doesn't fit into your GPU. + +## Setting up the models + +#### Pre-converted + +Download the GGUF models directly into your `text-generation-webui/models` folder. It will be a single file. + +* Make sure its name ends in `.gguf`. +* `q4_K_M` quantization is recommended. + +#### Convert Llama yourself + +Follow the instructions in the llama.cpp README to generate a GGUF: https://github.com/ggerganov/llama.cpp#prepare-data--run + +## GPU acceleration + +Enabled with the `--n-gpu-layers` parameter. + +* If you have enough VRAM, use a high number like `--n-gpu-layers 1000` to offload all layers to the GPU. +* Otherwise, start with a low number like `--n-gpu-layers 10` and then gradually increase it until you run out of memory. + +This feature works out of the box for NVIDIA GPUs on Linux (amd64) or Windows. For other GPUs, you need to uninstall `llama-cpp-python` with + +``` +pip uninstall -y llama-cpp-python +``` + +and then recompile it using the commands here: https://pypi.org/project/llama-cpp-python/ + +#### macOS + +For macOS, these are the commands: + +``` +pip uninstall -y llama-cpp-python +CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir +``` diff --git a/download-model.py b/download-model.py index dedd5f6..d9b21d3 100644 --- a/download-model.py +++ b/download-model.py @@ -22,19 +22,31 @@ from requests.adapters import HTTPAdapter from tqdm.contrib.concurrent import thread_map +base = "https://huggingface.co" + + class ModelDownloader: - def __init__(self, max_retries = 5): - self.s = requests.Session() + def __init__(self, max_retries=5): + self.session = requests.Session() if max_retries: - self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) - self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) + self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) + self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: - self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) + self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) + if os.getenv('HF_TOKEN') is not None: + self.session.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'} def sanitize_model_and_branch_names(self, model, branch): if model[-1] == '/': model = model[:-1] + if model.startswith(base + '/'): + model = model[len(base) + 1:] + + model_parts = model.split(":") + model = model_parts[0] if len(model_parts) > 0 else model + branch = model_parts[1] if len(model_parts) > 1 else branch + if branch is None: branch = "main" else: @@ -45,8 +57,7 @@ class ModelDownloader: return model, branch - def get_download_links_from_huggingface(self, model, branch, text_only=False): - base = "https://huggingface.co" + def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None): page = f"/api/models/{model}/tree/{branch}" cursor = b"" @@ -55,12 +66,12 @@ class ModelDownloader: classifications = [] has_pytorch = False has_pt = False - # has_ggml = False + has_gguf = False has_safetensors = False is_lora = False while True: url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "") - r = self.s.get(url, timeout=20) + r = self.session.get(url, timeout=10) r.raise_for_status() content = r.content @@ -70,16 +81,19 @@ class ModelDownloader: for i in range(len(dict)): fname = dict[i]['path'] + if specific_file not in [None, ''] and fname != specific_file: + continue + if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')): is_lora = True - is_pytorch = re.match("(pytorch|adapter|gptq)_model.*\.bin", fname) - is_safetensors = re.match(".*\.safetensors", fname) - is_pt = re.match(".*\.pt", fname) - is_ggml = re.match(".*ggml.*\.bin", fname) - is_tokenizer = re.match("(tokenizer|ice).*\.model", fname) - is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer - if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)): + is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname) + is_safetensors = re.match(r".*\.safetensors", fname) + is_pt = re.match(r".*\.pt", fname) + is_gguf = re.match(r'.*\.gguf', fname) + is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname) + is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer + if any((is_pytorch, is_safetensors, is_pt, is_gguf, is_tokenizer, is_text)): if 'lfs' in dict[i]: sha256.append([fname, dict[i]['lfs']['oid']]) @@ -99,9 +113,9 @@ class ModelDownloader: elif is_pt: has_pt = True classifications.append('pt') - elif is_ggml: - # has_ggml = True - classifications.append('ggml') + elif is_gguf: + has_gguf = True + classifications.append('gguf') cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(cursor) @@ -113,12 +127,17 @@ class ModelDownloader: if classifications[i] in ['pytorch', 'pt']: links.pop(i) - return links, sha256, is_lora + is_llamacpp = has_gguf and specific_file is not None + return links, sha256, is_lora, is_llamacpp - def get_output_folder(self, model, branch, is_lora, base_folder=None): + def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, base_folder=None): if base_folder is None: base_folder = 'models' if not is_lora else 'loras' + # If the model is of type GGUF, save directly in the base_folder + if is_llamacpp: + return Path(base_folder) + output_folder = f"{'_'.join(model.split('/')[-2:])}" if branch != 'main': output_folder += f'_{branch}' @@ -134,7 +153,7 @@ class ModelDownloader: if output_path.exists() and not start_from_scratch: # Check if the file has already been downloaded completely - r = self.s.get(url, stream=True, timeout=20) + r = self.session.get(url, stream=True, timeout=10) total_size = int(r.headers.get('content-length', 0)) if output_path.stat().st_size >= total_size: return @@ -143,7 +162,7 @@ class ModelDownloader: headers = {'Range': f'bytes={output_path.stat().st_size}-'} mode = 'ab' - with self.s.get(url, stream=True, headers=headers, timeout=20) as r: + with self.session.get(url, stream=True, headers=headers, timeout=10) as r: r.raise_for_status() # Do not continue the download if the request was unsuccessful total_size = int(r.headers.get('content-length', 0)) block_size = 1024 * 1024 # 1MB @@ -155,29 +174,34 @@ class ModelDownloader: f.write(data) if total_size != 0 and self.progress_bar is not None: count += len(data) - self.progress_bar(float(count) / float(total_size), f"Downloading {filename}") + self.progress_bar(float(count) / float(total_size), f"{filename}") def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1): thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) - def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1): + def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1, specific_file=None, is_llamacpp=False): self.progress_bar = progress_bar - # Creating the folder and writing the metadata + # Create the folder and writing the metadata output_folder.mkdir(parents=True, exist_ok=True) - metadata = f'url: https://huggingface.co/{model}\n' \ - f'branch: {branch}\n' \ - f'download date: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n' - sha256_str = '\n'.join([f' {item[1]} {item[0]}' for item in sha256]) - if sha256_str: - metadata += f'sha256sum:\n{sha256_str}' + if not is_llamacpp: + metadata = f'url: https://huggingface.co/{model}\n' \ + f'branch: {branch}\n' \ + f'download date: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n' - metadata += '\n' - (output_folder / 'huggingface-metadata.txt').write_text(metadata) + sha256_str = '\n'.join([f' {item[1]} {item[0]}' for item in sha256]) + if sha256_str: + metadata += f'sha256sum:\n{sha256_str}' + + metadata += '\n' + (output_folder / 'huggingface-metadata.txt').write_text(metadata) + + if specific_file: + print(f"Downloading {specific_file} to {output_folder}") + else: + print(f"Downloading the model to {output_folder}") - # Downloading the files - print(f"Downloading the model to {output_folder}") self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads) def check_model_files(self, model, branch, links, sha256, output_folder): @@ -213,6 +237,7 @@ if __name__ == '__main__': parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') + parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).') parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') @@ -221,28 +246,29 @@ if __name__ == '__main__': branch = args.branch model = args.MODEL + specific_file = args.specific_file if model is None: print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').") sys.exit() downloader = ModelDownloader(max_retries=args.max_retries) - # Cleaning up the model/branch names + # Clean up the model/branch names try: model, branch = downloader.sanitize_model_and_branch_names(model, branch) except ValueError as err_branch: print(f"Error: {err_branch}") sys.exit() - # Getting the download links from Hugging Face - links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only) + # Get the download links from Hugging Face + links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file) - # Getting the output folder - output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=args.output) + # Get the output folder + output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=args.output) if args.check: # Check previously downloaded files downloader.check_model_files(model, branch, links, sha256, output_folder) else: # Download files - downloader.download_model_files(model, branch, links, sha256, output_folder, threads=args.threads) + downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp) diff --git a/extensions/Training_PRO/custom_scheduler.py b/extensions/Training_PRO/custom_scheduler.py new file mode 100644 index 0000000..1f1597a --- /dev/null +++ b/extensions/Training_PRO/custom_scheduler.py @@ -0,0 +1,96 @@ +from functools import partial +import torch +import transformers +import math +from torch.optim.lr_scheduler import LambdaLR + + +#FPHAM custom training scheduller block - should be extracted to separate file +last_print_label = '' + +def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int): + + global last_print_label + print_label = '' + + num_warmup_steps = min(num_warmup_steps,num_firstepoch_steps) + + if current_step < num_warmup_steps: + print_label = 'Scheduler: Warmup' + elif current_step < num_firstepoch_steps: + print_label = 'Scheduler: Hold' + else: + print_label = 'Scheduler: Annealing' + + if print_label != last_print_label: + print(print_label) + + last_print_label = print_label + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + if current_step < num_firstepoch_steps: + return 1.0 + + progress = float(current_step - num_firstepoch_steps) / float(max(1, num_training_steps - num_firstepoch_steps)) + num_cycles = 0.5 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + +def custom_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1): + """ + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_fp_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_firstepoch_steps = num_firstepoch_steps, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + +class FPSchedulerTrainer(transformers.Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + #Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument. + + if self.args.lr_scheduler_type == 'cosine': + num_train_epochs = self.args.num_train_epochs + num_warmup_steps=self.args.get_warmup_steps(num_training_steps) + num_firstepoch_steps = math.ceil(num_training_steps/num_train_epochs) + num_warmup_acc = num_warmup_steps*self.args.gradient_accumulation_steps + num_firstepoch_steps_acc = num_firstepoch_steps*self.args.gradient_accumulation_steps + num_training_steps_acc = num_training_steps*self.args.gradient_accumulation_steps + num_warmup_acc_min = min(num_warmup_acc, num_firstepoch_steps_acc) + + if num_warmup_acc>num_firstepoch_steps_acc: + print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to 1 epoch, essentially going from warmup to annealing.\033[0;37;0m") + print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}") + else: + print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}") + + self.lr_scheduler = custom_scheduler_with_warmup( + optimizer=self.optimizer if optimizer is None else optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_firstepoch_steps = num_firstepoch_steps, + ) + self._created_lr_scheduler = True + return self.lr_scheduler + else: + return super().create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) \ No newline at end of file diff --git a/extensions/Training_PRO/matplotgraph.py b/extensions/Training_PRO/matplotgraph.py new file mode 100644 index 0000000..348fc01 --- /dev/null +++ b/extensions/Training_PRO/matplotgraph.py @@ -0,0 +1,62 @@ +import os +import json + +def create_graph(lora_path, lora_name): + try: + import matplotlib.pyplot as plt + from matplotlib.ticker import ScalarFormatter + + peft_model_path = f'{lora_path}/training_graph.json' + image_model_path = f'{lora_path}/training_graph.png' + # Check if the JSON file exists + if os.path.exists(peft_model_path): + # Load data from JSON file + with open(peft_model_path, 'r') as file: + data = json.load(file) + # Extract x, y1, and y2 values + x = [item['epoch'] for item in data] + y1 = [item['learning_rate'] for item in data] + y2 = [item['loss'] for item in data] + + # Create the line chart + fig, ax1 = plt.subplots(figsize=(10, 6)) + + + # Plot y1 (learning rate) on the first y-axis + ax1.plot(x, y1, 'b-', label='Learning Rate') + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Learning Rate', color='b') + ax1.tick_params('y', colors='b') + + # Create a second y-axis + ax2 = ax1.twinx() + + # Plot y2 (loss) on the second y-axis + ax2.plot(x, y2, 'r-', label='Loss') + ax2.set_ylabel('Loss', color='r') + ax2.tick_params('y', colors='r') + + # Set the y-axis formatter to display numbers in scientific notation + ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True)) + ax1.ticklabel_format(style='sci', axis='y', scilimits=(0,0)) + + # Add grid + ax1.grid(True) + + # Combine the legends for both plots + lines, labels = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax2.legend(lines + lines2, labels + labels2, loc='best') + + # Set the title + plt.title(f'{lora_name} LR and Loss vs Epoch') + + # Save the chart as an image + plt.savefig(image_model_path) + + print(f"Graph saved in {image_model_path}") + else: + print(f"File 'training_graph.json' does not exist in the {lora_path}") + + except ImportError: + print("matplotlib is not installed. Please install matplotlib to create PNG graphs") \ No newline at end of file diff --git a/extensions/Training_PRO/readme.md b/extensions/Training_PRO/readme.md new file mode 100644 index 0000000..c0647db --- /dev/null +++ b/extensions/Training_PRO/readme.md @@ -0,0 +1,27 @@ +This is an expanded Training tab + + +- Chunking: precise raw text slicer (PRTS) uses sentence slicing and making sure things are clean on all ends +- overlap chunking - this special overlapping will make additional overlap block based on logical rules (aka no overlap block on hard cut) +- custom scheduler (follow the code to make your own) In LR Scheduler select FP_low_epoch_annealing - this scheduler will keep the LR constant for first epoch then use cosine for the rest - this part would be best to spawn into a new py file +- save loss threshold - will not save the "Save every n steps" checkpoints until this threshold is reached (I definitely don't need multiple checkpoints that are 2.5 loss - I'm usually interested in checkpoints between say 1.5 and 1.9 loss) +- saves graph png file at the end with learning rate and loss per epoch +- adding EOS to each block or to hard cut only +- automatically lowers gradient accumulation if you go overboard and set gradient accumulation that will be higher than actual data - transformers would then throw error (or they used to, not sure if still true) but in any way, it will fix bad data +- turn BOS on and OFF +- target selector + +###Notes: + +This uses it's own chunking code for raw text based on sentence splitting. This will avoid weird cuts in the chunks and each chunk should now start with sentence and end on some sentence. It works hand in hand with Hard Cut. +A propper use is to structure your text into logical blocks (ideas) separated by three \n then use three \n in hard cut. +This way each chunk will contain only one flow of ideas and not derail in the thoughts. +And Overlapping code will create overlapped blocks on sentence basis too, but not cross hard cut, thus not cross different ideas either. +Does it make any sense? No? Hmmmm... + +###Targets + +Normal LORA is q, v and that's what you should use. +You can use (q k v o) or (q k v) and it will give you a lot more trainable parameters. The benefit is that you can keep rank lower and still attain the same coherency as q v with high rank. Guanaco has been trained with QLORA and q k v o for example and they swear by it. +I also added k-v-down which is lifted from IA3, which is very odd one to use for LORA, but it created adorable style craziness when training on raw structured text and bringing the loss all the way down to 1.1 . It didn't overfit (q-v would be just writing entire novels at loss 1.1) and it followed the instruction seeping from the previous fine-tuning. YMMW of course. +Using All will train all 7 targets q-k-v-o-up,down, gate - not sure if there is much benefit from attention only qkvo. It sure makes LORA huge. If that's what you like. diff --git a/extensions/Training_PRO/script.py b/extensions/Training_PRO/script.py new file mode 100644 index 0000000..d5b0964 --- /dev/null +++ b/extensions/Training_PRO/script.py @@ -0,0 +1,810 @@ +import os + +os.environ["WANDB_MODE"] = "offline" +# os.environ["WANDB_DISABLED"] = "true" + +import json +import math +import random +import shutil +import sys +import threading +import time +import traceback +from datetime import datetime +from pathlib import Path + +import gradio as gr +import torch +import transformers + +from .custom_scheduler import FPSchedulerTrainer +from .matplotgraph import create_graph +from .train_utils import get_available_loras_local, precise_cut + +from datasets import Dataset, load_dataset +from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_kbit_training, + set_peft_model_state_dict +) +from peft.utils.other import \ + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +) + +from modules import shared, utils +from modules.ui import create_refresh_button + +from modules.evaluate import ( + calculate_perplexity, + generate_markdown_table, + save_past_evaluations +) +from modules.logging_colors import logger +from modules.models import reload_model +from modules.utils import natural_keys + + +params = { + "display_name": "Training PRO", + "is_tab": True +} + +MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()} +PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to", "precize_slicing_overlap", "add_eos_token_type", "save_steps_under_loss", "add_bos_token", "training_projection"] +WANT_INTERRUPT = False + +train_log = {} +train_template = {} +train_log_graph = [] +Lora_sortedByTime = False +train_choices = ["all","q-k-v-o","q-k-v","k-v-down","q-v"] + + + +def ui(): + with gr.Tab('Train LoRA', elem_id='lora-train-tab'): + tmp = gr.State('') + with gr.Row(): + with gr.Column(): + gr.Markdown("This is enhanced version of Lora Training with a sentence based RAW text chunking code") + + with gr.Row(): + with gr.Column(scale=5): + with gr.Row(): + copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=get_available_loras_local(Lora_sortedByTime), elem_classes=['slim-dropdown']) + create_refresh_button(copy_from, lambda: None, lambda: {'choices': get_available_loras_local(Lora_sortedByTime)}, 'refresh-button') + with gr.Column(): + sort_byTime = gr.Checkbox(label='Sort list by Date', value=False, info='Sorts Loras by date created.', elem_classes=['no-background']) + + with gr.Row(): + with gr.Column(scale=5): + lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file') + + with gr.Column(): + always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background']) + + with gr.Row(): + with gr.Column(): + lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.') + lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.') + batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.') + micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.') + cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.') + + with gr.Column(): + save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.') + save_steps_under_loss = gr.Slider(label='Save Loss Threshold', value=1.9, minimum=0.0, maximum=3.0, step=0.1, info='Save checkpoints only if the loss is less or equall Threshold loss. (0 = save all)') + epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.') + learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.') + lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt', 'FP_low_epoch_annealing'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown']) + + with gr.Accordion(label='Advanced Options', open=True): + with gr.Row(): + with gr.Column(): + lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.') + stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)') + training_projection = gr.Radio(value = train_choices[4], label='LLaMA Target Projections', info='Change the targets (LORA is typically q-v)', choices=train_choices) + + optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown']) + + with gr.Column(): + warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.') + train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.') + add_bos_token = gr.Checkbox(label='Add BOS token', value=True, info="Adds BOS token for each dataset item") + add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item") + add_eos_token_type = gr.Dropdown(label='EOS placement (raw text)', choices=['Every Block', 'Hard Cut Blocks Only'], value='Every Block', info='', allow_custom_value = False) + + higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.') + report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True) + + with gr.Column(): + with gr.Tab(label='Formatted Dataset'): + with gr.Row(): + format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown']) + create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button') + + with gr.Row(): + dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown']) + create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button') + + with gr.Row(): + eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown']) + create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button') + + eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.') + + with gr.Tab(label="Raw text file"): + with gr.Row(): + raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.', elem_classes=['slim-dropdown']) + create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button') + + with gr.Row(): + with gr.Column(): + precize_slicing_overlap = gr.Checkbox(label='Create Overlapping blocks', value = True) + with gr.Column(): + hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a cut between logical blocks of text (ex. Ideas or Chapters). Helps prevent unwanted overlap between unrelated ideas.') + min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Text blocks that have less or equal characters than this number.') + + with gr.Row(): + start_button = gr.Button("Start LoRA Training", variant='primary') + stop_button = gr.Button("Interrupt") + + output = gr.Markdown(value="Ready") + + with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'): + with gr.Row(): + with gr.Column(): + models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True) + evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.') + with gr.Row(): + with gr.Column(): + stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') + + with gr.Column(): + max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.') + + with gr.Row(): + start_current_evaluation = gr.Button("Evaluate loaded model") + start_evaluation = gr.Button("Evaluate selected models") + stop_evaluation = gr.Button("Interrupt") + + with gr.Column(): + evaluation_log = gr.Markdown(value='') + + evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True) + with gr.Row(): + save_comments = gr.Button('Save comments', elem_classes="small-button") + refresh_table = gr.Button('Refresh the table', elem_classes="small-button") + + # Training events + all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to, precize_slicing_overlap, add_eos_token_type, save_steps_under_loss, add_bos_token, training_projection] + + copy_from.change(do_copy_params, [copy_from] + all_params, all_params) + start_button.click(do_train, all_params, output) + stop_button.click(do_interrupt, None, None, queue=False) + higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha]) + + # Evaluation events. For some reason, the interrupt event + # doesn't work with the .then() syntax, so I write them one + # by one in this ugly but functional way. + ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False) + start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False) + + start_current_evaluation.click(lambda: ['current model'], None, tmp) + ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False) + start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False) + + stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False) + refresh_table.click(generate_markdown_table, None, evaluation_table, show_progress=True) + save_comments.click( + save_past_evaluations, evaluation_table, None).then( + lambda: "Comments saved.", None, evaluation_log, show_progress=False) + + def reload_lora(): + global Lora_sortedByTime + return gr.Dropdown.update(choices=get_available_loras_local(Lora_sortedByTime)) + + def global_lora_time(sort_byTime): + global Lora_sortedByTime + Lora_sortedByTime = sort_byTime + + + sort_byTime.change(global_lora_time, sort_byTime, None).then(reload_lora,None,copy_from) + + +def do_interrupt(): + global WANT_INTERRUPT + WANT_INTERRUPT = True + + +def do_copy_params(lora_name: str, *args): + f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json" + if Path(f_name).is_file(): + with open(f_name, 'r', encoding='utf-8') as format_file: + params: dict[str, str] = json.load(format_file) + else: + params = {} + + result = list() + for i in range(0, len(PARAMETERS)): + key = PARAMETERS[i] + if key in params: + result.append(params[key]) + else: + result.append(args[i]) + + return result + + +def change_rank_limit(use_higher_ranks: bool): + mult = 2 if use_higher_ranks else 1 + return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"} + + +def clean_path(base_path: str, path: str): + """Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" + path = path.replace('\\', '/').replace('..', '_') + if base_path is None: + return path + + return f'{Path(base_path).absolute()}/{path}' + + +def backup_adapter(input_folder): + # Get the creation date of the file adapter_model.bin + try: + adapter_file = Path(f"{input_folder}/adapter_model.bin") + if adapter_file.is_file(): + + logger.info("Backing up existing LoRA adapter...") + creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime) + creation_date_str = creation_date.strftime("Backup-%Y-%m-%d") + + # Create the new subfolder + subfolder_path = Path(f"{input_folder}/{creation_date_str}") + subfolder_path.mkdir(parents=True, exist_ok=True) + + # Check if the file already exists in the subfolder + backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin") + if backup_adapter_file.is_file(): + print(" - Backup already exists. Skipping backup process.") + return + + # Copy existing files to the new subfolder + existing_files = Path(input_folder).iterdir() + for file in existing_files: + if file.is_file(): + shutil.copy2(file, subfolder_path) + except Exception as e: + print("An error occurred in backup_adapter:", str(e)) + + +def calc_trainable_parameters(model): + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + +def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str, precize_slicing_overlap: bool, add_eos_token_type: str, save_steps_under_loss: float, add_bos_token: bool, training_projection: str): + + if shared.args.monkey_patch: + from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( + replace_peft_model_with_int4_lora_model + ) + replace_peft_model_with_int4_lora_model() + + global WANT_INTERRUPT + WANT_INTERRUPT = False + + # == Input validation / processing == + yield "Preparing the input..." + lora_file_path = clean_path(None, lora_name) + if lora_file_path.strip() == '': + yield "Missing or invalid LoRA file name input." + return + + lora_file_path = f"{Path(shared.args.lora_dir)}/{lora_file_path}" + actual_lr = float(learning_rate) + model_type = type(shared.model).__name__ + + if model_type in MODEL_CLASSES: + model_id = MODEL_CLASSES[model_type] + else: + model_id = "llama" + if model_type == "PeftModelForCausalLM": + if len(shared.lora_names) > 0: + yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" + logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.") + else: + yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" + logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.") + else: + yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" + logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})") + + time.sleep(5) + + if shared.args.loader == 'GPTQ-for-LLaMa' and not shared.args.monkey_patch: + yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`" + return + + if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: + yield "Cannot input zeroes." + return + + gradient_accumulation_steps = batch_size // micro_batch_size + shared.tokenizer.pad_token_id = 0 + shared.tokenizer.padding_side = "left" + + def encode(text, prepend_bos_token): + + result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len) + # Check if the first two tokens are BOS + if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]: + result = result[1:] + + if not prepend_bos_token and result[0] == shared.tokenizer.bos_token_id: + result = result[1:] + return result + + def tokenize(prompt, append_eos_token=False, prepend_bos_token = False): + + if train_only_after == '' or train_only_after not in prompt: + input_ids = encode(prompt, prepend_bos_token) + + if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len: + input_ids.append(shared.tokenizer.eos_token_id) + + input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids + + labels = [1] * len(input_ids) + else: + ind = prompt.index(train_only_after) + len(train_only_after) + before_tokens = encode(prompt[:ind], prepend_bos_token) + after_tokens = encode(prompt[ind:], False) + + if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id: + after_tokens.append(shared.tokenizer.eos_token_id) + + full_length = len(after_tokens) + len(before_tokens) + if full_length > cutoff_len: + after_tokens = after_tokens[:cutoff_len - len(before_tokens)] + else: + before_tokens = [shared.tokenizer.pad_token_id] * (cutoff_len - full_length) + before_tokens + + input_ids = before_tokens + after_tokens + labels = [-100] * len(before_tokens) + [1] * len(after_tokens) + + input_ids = torch.tensor(input_ids) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": input_ids.ne(shared.tokenizer.pad_token_id), + } + + train_template.clear() + + + + print(f"*** LoRA: {lora_name} ***") + + # END OF FPHAM SENTENCE SPLIT functions =================== + + # == Prep the dataset, format, etc == + if raw_text_file not in ['None', '']: + train_template["template_type"] = "raw_text" + logger.info("Loading raw text file dataset...") + fullpath = clean_path('training/datasets', f'{raw_text_file}') + fullpath = Path(fullpath) + if fullpath.is_dir(): + logger.info('Training path directory {}'.format(raw_text_file)) + raw_text = "" + file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name)) + for file_path in file_paths: + if file_path.is_file(): + with file_path.open('r', encoding='utf-8') as file: + raw_text += file.read().replace('\r', '') + + logger.info(f"Loaded training file: {file_path.name}") + else: + with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: + raw_text = file.read().replace('\r', '') + + # FPHAM PRECISE SLICING + if min_chars<0: + min_chars = 0 + + add_EOS_to_all = add_eos_token and add_eos_token_type == 'Every Block' + add_EOS_to_HC = add_eos_token and add_eos_token_type != 'Every Block' + + #print (f"add_eos_token {add_eos_token}, add_EOS_to_all {add_EOS_to_all}, add_EOS_to_HC {add_EOS_to_HC}") + + # == New more precise slicing on sentence boundary == + text_chunks = precise_cut(raw_text, precize_slicing_overlap, min_chars, add_EOS_to_HC, cutoff_len, hard_cut_string) + train_data = Dataset.from_list([tokenize(x, add_EOS_to_all, add_bos_token) for x in text_chunks]) + if add_EOS_to_all: + print(f"Added EOS to {len(text_chunks)} blocks") + + del text_chunks + eval_data = None + else: + if dataset in ['None', '']: + yield "Missing dataset choice input, cannot continue." + return + + if format in ['None', '']: + yield "Missing format choice input, cannot continue." + return + + train_template["template_type"] = "dataset" + + with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile: + format_data: dict[str, str] = json.load(formatFile) + + # == store training prompt == + for _, value in format_data.items(): + prompt_key = f"template_{len(train_template)}" + train_template[prompt_key] = value + + def generate_prompt(data_point: dict[str, str]): + for options, data in format_data.items(): + if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)): + for key, val in data_point.items(): + if type(val) is str: + data = data.replace(f'%{key}%', val) + return data + raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"') + + def generate_and_tokenize_prompt(data_point): + prompt = generate_prompt(data_point) + return tokenize(prompt, add_eos_token, add_bos_token) + + logger.info("Loading JSON datasets...") + data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json')) + train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30)) + + print(f"BOS: {add_bos_token} EOS: {add_eos_token}") + + if eval_dataset == 'None': + eval_data = None + else: + eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json')) + eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30)) + + # == We MUST reload model if it went through any previous training, even failed one == + if shared.model_dirty_from_training: + selected_model = shared.model_name + if selected_model: + print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m") + try: + yield f"Reloading {selected_model}..." + reload_model() + if shared.model is not None: + print("Model reloaded OK, continue with training.") + else: + return f"Failed to load {selected_model}." + except: + exc = traceback.format_exc() + logger.error('Failed to reload the model.') + print(exc) + return exc.replace('\n', '\n\n') + + # == Start prepping the model itself == + if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): + logger.info("Getting model ready...") + prepare_model_for_kbit_training(shared.model) + + # base model is now frozen and should not be reused for any other LoRA training than this one + shared.model_dirty_from_training = True + if training_projection==train_choices[0]: + model_to_lora_modules["llama"] = ["gate_proj","down_proj","up_proj","q_proj","k_proj","v_proj","o_proj"] + elif training_projection==train_choices[1]: + model_to_lora_modules["llama"] = ["q_proj","k_proj", "v_proj", "o_proj"] + elif training_projection==train_choices[2]: + model_to_lora_modules["llama"] = ["q_proj","k_proj", "v_proj"] + elif training_projection==train_choices[3]: + model_to_lora_modules["llama"] = ["k_proj", "v_proj", "down_proj"] + else: + model_to_lora_modules["llama"] = ["q_proj", "v_proj"] + + + logger.info("Preparing for training...") + config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=model_to_lora_modules[model_id], + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM" + ) + + # == Backup the existing adapter == + if not always_override: + backup_adapter(lora_file_path) + + # == get model trainable params + model_trainable_params, model_all_params = calc_trainable_parameters(shared.model) + + try: + logger.info("Creating LoRA model...") + lora_model = get_peft_model(shared.model, config) + if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file(): + logger.info("Loading existing LoRA data...") + state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin") + set_peft_model_state_dict(lora_model, state_dict_peft) + except: + yield traceback.format_exc().replace('\n', '\n\n') + return + + if shared.args.monkey_patch: + from alpaca_lora_4bit.autograd_4bit import Autograd4bitQuantLinear + from alpaca_lora_4bit.models import Linear4bitLt + for _, m in lora_model.named_modules(): + if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt): + if m.is_v1_model: + m.zeros = m.zeros.half() + m.scales = m.scales.half() + + class Tracked(): + def __init__(self): + self.current_steps = 0 + self.max_steps = 0 + self.did_save = False + + tracked = Tracked() + actual_save_steps = math.ceil(save_steps / gradient_accumulation_steps) + + class Callbacks(transformers.TrainerCallback): + def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): + tracked.current_steps = state.global_step * gradient_accumulation_steps + tracked.max_steps = state.max_steps * gradient_accumulation_steps + if WANT_INTERRUPT: + control.should_epoch_stop = True + control.should_training_stop = True + elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0: + current_loss = float(train_log.get('loss', 0.0)) + if current_loss <= save_steps_under_loss or save_steps_under_loss==0.0: + lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/") + print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Checkpoint-{tracked.current_steps} saved") + # Save log + with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file: + json.dump(train_log, file, indent=2) + # == Save training prompt == + with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_prompt.json", 'w', encoding='utf-8') as file: + json.dump(train_template, file, indent=2) + + def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): + tracked.current_steps += 1 + if WANT_INTERRUPT: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs): + train_log.update(logs) + train_log.update({"current_steps": tracked.current_steps}) + if WANT_INTERRUPT: + print("\033[1;31;1mInterrupted by user\033[0;37;0m") + + print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m", end='') + + entry = { + 'current_steps': int(train_log.get('current_steps',0)), + 'loss': float(train_log.get('loss', 0.0)), + 'learning_rate': float(train_log.get('learning_rate', 0.0)), + 'epoch': float(train_log.get('epoch', 0.0)) + } + + # Add the entry to the continuous log + train_log_graph.append(entry) + + # Save the graph log for now, we can later generate full graph + with open(f"{lora_file_path}/training_graph.json", 'w') as file: + json.dump(train_log_graph, file, indent=4) + + if 'loss' in logs: + loss = float(logs['loss']) + if loss <= stop_at_loss: + control.should_epoch_stop = True + control.should_training_stop = True + print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m") + + # FPHAM SAMPLE REQ Transformers error handling + sample_req = int(train_data.num_rows)//micro_batch_size + + if sample_req < gradient_accumulation_steps: + print(f"\033[1;31;1mWARNING: Current gradient accumulation is too high for the amount of training data.\033[0;37;0m") + print(f"Gradient accumulation: {gradient_accumulation_steps} should be less than: {sample_req}. \033[1;31;1mThis could crash Accelerate/Transformers\033[0;37;0m") + min_batchSize = sample_req*micro_batch_size + print(f"Preferable fix: \033[1;31;1mIncrease the size of dataset\033[0;37;0m") + print(f"... or Decrerase Batch Size \033[1;31;1m{batch_size}\033[0;37;0m to below {min_batchSize}") + gradient_accumulation_steps = max(1,sample_req-1) + print(f"Last resort fix for this run: Lowering Gradient accumulation to {gradient_accumulation_steps}. [Good luck]") + + else: + print(f"Data Size Check: Gradient accumulation: {gradient_accumulation_steps} <= Data/Batch {sample_req} ... [OK]") + + #END OF FPHAM SAMPLE REQ + + # FPHAM Custom Scheduler == + custom_scheduller = False + lr_scheduler_type_arg = lr_scheduler_type + + if lr_scheduler_type == 'FP_low_epoch_annealing': + custom_scheduller = True + lr_scheduler_type_arg = 'cosine' + + args=transformers.TrainingArguments( + report_to=report_to if report_to != "None" else None, + per_device_train_batch_size=micro_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps), + num_train_epochs=epochs, + learning_rate=actual_lr, + fp16=False if shared.args.cpu else True, + optim=optimizer, + logging_steps=1, + evaluation_strategy="steps" if eval_data is not None else "no", + eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None, + save_strategy="steps" if eval_data is not None else "no", + output_dir=lora_file_path, + lr_scheduler_type=lr_scheduler_type_arg, + load_best_model_at_end=eval_data is not None, + # TODO: Enable multi-device support + ddp_find_unused_parameters=None, + no_cuda=shared.args.cpu, + ) + + if custom_scheduller: + trainer = FPSchedulerTrainer( + model=lora_model, + train_dataset=train_data, + eval_dataset=eval_data, + args=args, + data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False), + callbacks=list([Callbacks()]) + ) + else: + trainer = transformers.Trainer( + model=lora_model, + train_dataset=train_data, + eval_dataset=eval_data, + args=args, + data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False), + callbacks=list([Callbacks()]) + ) + + # END OF FPHAM CUSTOM SCHEDULER + + lora_model.config.use_cache = False + + if torch.__version__ >= "2" and sys.platform != "win32": + lora_model = torch.compile(lora_model) + + # == Save parameters for reuse == + with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file: + vars = locals() + json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2) + + # == Save training prompt == + with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file: + json.dump(train_template, file, indent=2) + + # == Main run and monitor loop == + logger.info("Starting training...") + yield "Starting..." + + lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) + + projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]]) + + print(f"Training '{model_id}' model using ({projections_string}) projections") + + if lora_all_param > 0: + print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})") + + train_log.update({"base_model_name": shared.model_name}) + train_log.update({"base_model_class": shared.model.__class__.__name__}) + train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)}) + train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)}) + train_log.update({"projections": projections_string}) + + if stop_at_loss > 0: + print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m") + + if WANT_INTERRUPT: + yield "Interrupted before start." + return + + def log_train_dataset(trainer): + decoded_entries = [] + # Try to decode the entries and write the log file + try: + # Iterate over the first 10 elements in the dataset (or fewer if there are less than 10) + for i in range(min(10, len(trainer.train_dataset))): + decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids']) + decoded_entries.append({"value": decoded_text}) + + # Write the log file + Path('logs').mkdir(exist_ok=True) + with open(Path('logs/train_dataset_sample.json'), 'w') as json_file: + json.dump(decoded_entries, json_file, indent=4) + + logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.") + except Exception as e: + logger.error(f"Failed to create log file due to error: {e}") + + def threaded_run(): + log_train_dataset(trainer) + trainer.train() + # Note: save in the thread in case the gradio thread breaks (eg browser closed) + lora_model.save_pretrained(lora_file_path) + logger.info("LoRA training run is completed and saved.") + # Save log + with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file: + json.dump(train_log, file, indent=2) + + thread = threading.Thread(target=threaded_run) + thread.start() + last_step = 0 + start_time = time.perf_counter() + + while thread.is_alive(): + time.sleep(0.5) + if WANT_INTERRUPT: + yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*" + + elif tracked.current_steps != last_step: + last_step = tracked.current_steps + time_elapsed = time.perf_counter() - start_time + if time_elapsed <= 0: + timer_info = "" + total_time_estimate = 999 + else: + its = tracked.current_steps / time_elapsed + if its > 1: + timer_info = f"`{its:.2f}` it/s" + else: + timer_info = f"`{1.0/its:.2f}` s/it" + + total_time_estimate = (1.0 / its) * (tracked.max_steps) + + yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining" + + # Saving in the train thread might fail if an error occurs, so save here if so. + if not tracked.did_save: + logger.info("Training complete, saving...") + lora_model.save_pretrained(lora_file_path) + + if WANT_INTERRUPT: + logger.info("Training interrupted.") + yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`." + else: + logger.info("Training complete!") + yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training." + + create_graph(lora_file_path, lora_name) + +def format_time(seconds: float): + if seconds < 120: + return f"`{seconds:.0f}` seconds" + + minutes = seconds / 60 + if minutes < 120: + return f"`{minutes:.0f}` minutes" + + hours = minutes / 60 + return f"`{hours:.0f}` hours" diff --git a/extensions/Training_PRO/train_utils.py b/extensions/Training_PRO/train_utils.py new file mode 100644 index 0000000..21f7d39 --- /dev/null +++ b/extensions/Training_PRO/train_utils.py @@ -0,0 +1,192 @@ +import os +from modules import shared, utils +from pathlib import Path +import json + +def list_subfoldersByTime(directory): + + if not directory.endswith('/'): + directory += '/' + subfolders = [] + path = directory + name_list = os.listdir(path) + full_list = [os.path.join(path,i) for i in name_list] + time_sorted_list = sorted(full_list, key=os.path.getmtime,reverse=True) + + for entry in time_sorted_list: + if os.path.isdir(entry): + entry_str = f"{entry}" # Convert entry to a string + full_path = entry_str + entry_str = entry_str.replace('\\','/') + entry_str = entry_str.replace(f"{directory}", "") # Remove directory part + subfolders.append(entry_str) + + return subfolders + +def get_available_loras_local(_sortedByTime): + + model_dir = shared.args.lora_dir # Update with the appropriate directory path + subfolders = [] + if _sortedByTime: + subfolders = list_subfoldersByTime(model_dir) + else: + subfolders = utils.get_available_loras() + + return subfolders + + +# FPHAM SPLIT BY SENTENCE BLOCK =============== + +def split_sentences(text: str, cutoff_len: int): + sentences = [] + sentence = '' + delimiters = ['. ', '? ', '! ', '... ', '.\n', '?\n', '!\n','...\n','',''] + abbreviations = ['Mr. ', 'Mrs. ', 'Dr. ', 'Ms. ', 'St. ', 'Prof. ', 'Jr. ', 'Ltd. ', 'Capt. ', 'Col. ', 'Gen. ', 'Ave. ', 'Blvd. ', 'Co. ', 'Corp. ', 'Dept. ', 'Est. ', 'Gov. ', 'Inc. ', 'Ph.D. ', 'Univ. '] + errors = 0 + max_cut = cutoff_len-1 + prev_char = '' + + for char in text: + sentence += char + + + if (any(sentence.endswith(delimiter) for delimiter in delimiters) and + not (prev_char.isupper() and len(sentence) >= 3 and sentence[-3] != ' ') and + not any(sentence.endswith(abbreviation) for abbreviation in abbreviations)): + tokens = shared.tokenizer.encode(sentence) + + if len(tokens) > max_cut: + tokens = tokens[:max_cut] + sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True) + errors = errors + 1 + + sentences.append({'text': sentence, 'size': len(tokens)}) + + sentence = '' + + prev_char = char + + if sentence: + tokens = shared.tokenizer.encode(sentence) + if len(tokens) > max_cut: + tokens = tokens[:max_cut] + sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True) + errors = errors + 1 + + sentences.append({'text': sentence, 'size': len(tokens)}) + + if errors > 0: + print(f"Trimmed sentences beyond Cutoff Length: {errors}") + + return sentences + +# The goal of following code is to create blocks of text + overlapping blocks while: +# respects sentence boundaries +# always uses all the text +# hard cut defined by hard_cut_string or will always end at the end of data block +# no overlapping blocks will be created across hard cut or across token + +def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str): + + debug_slicer = False + EOSX_str = '' #hardcut placeholder + EOS_str = '' + print("Precise raw text slicer: ON") + + cut_string = hard_cut_string.replace('\\n', '\n') + text = text.replace(cut_string, EOSX_str) + sentences = split_sentences(text, cutoff_len) + + print(f"Sentences: {len(sentences)}") + sentencelist = [] + currentSentence = '' + totalLength = 0 + max_cut = cutoff_len-1 + half_cut = cutoff_len//2 + halfcut_length = 0 + + edgeindex = [] + half_index = 0 + + for index, item in enumerate(sentences): + + if halfcut_length+ item['size'] < half_cut: + halfcut_length += item['size'] + half_index = index + else: + edgeindex.append(half_index) + halfcut_length = -2 * max_cut + + + if totalLength + item['size'] < max_cut and not currentSentence.endswith(EOSX_str): + currentSentence += item['text'] + totalLength += item['size'] + else: + + if len(currentSentence.strip()) > min_chars_cut: + sentencelist.append(currentSentence.strip()) + + currentSentence = item['text'] + totalLength = item['size'] + halfcut_length = item['size'] + + if len(currentSentence.strip()) > min_chars_cut: + sentencelist.append(currentSentence.strip()) + + unique_blocks = len(sentencelist) + print(f"Text Blocks: {unique_blocks}") + + #overlap strategies: + # don't overlap across HARD CUT (EOSX) + if overlap: + for edge_idx in edgeindex: + currentSentence = '' + totalLength = 0 + + for item in sentences[edge_idx:]: + if totalLength + item['size'] < max_cut: + currentSentence += item['text'] + totalLength += item['size'] + else: + #if by chance EOSX is at the end then it's acceptable + if currentSentence.endswith(EOSX_str) and len(currentSentence.strip()) > min_chars_cut: + sentencelist.append(currentSentence.strip()) + # otherwise don't cross hard cut + elif EOSX_str not in currentSentence and len(currentSentence.strip()) > min_chars_cut: + sentencelist.append(currentSentence.strip()) + + currentSentence = '' + totalLength = 0 + break + + print(f"+ Overlapping blocks: {len(sentencelist)-unique_blocks}") + + num_EOS = 0 + for i in range(len(sentencelist)): + if eos_to_hc: + sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str) + else: + sentencelist[i] = sentencelist[i].replace(EOSX_str, '') + + #someone may have had stop strings in the raw text... + sentencelist[i] = sentencelist[i].replace("", EOS_str) + num_EOS += sentencelist[i].count(EOS_str) + + if num_EOS > 0: + print(f"+ EOS count: {num_EOS}") + + #final check for useless lines + sentencelist = [item for item in sentencelist if item.strip() != ""] + sentencelist = [item for item in sentencelist if item.strip() != ""] + + + if debug_slicer: + # Write the log file + Path('logs').mkdir(exist_ok=True) + sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)} + output_file = "logs/sentencelist.json" + with open(output_file, 'w') as f: + json.dump(sentencelist_dict, f,indent=2) + + + return sentencelist \ No newline at end of file diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py index edc6d8f..a91fd51 100644 --- a/extensions/api/blocking_api.py +++ b/extensions/api/blocking_api.py @@ -7,10 +7,12 @@ from modules import shared from modules.chat import generate_chat_reply from modules.LoRA import add_lora_to_model from modules.models import load_model, unload_model -from modules.models_settings import (get_model_settings_from_yamls, - update_model_parameters) -from modules.text_generation import (encode, generate_reply, - stop_everything_event) +from modules.models_settings import get_model_metadata, update_model_parameters +from modules.text_generation import ( + encode, + generate_reply, + stop_everything_event +) from modules.utils import get_available_models @@ -127,8 +129,8 @@ class Handler(BaseHTTPRequestHandler): shared.model_name = model_name unload_model() - model_settings = get_model_settings_from_yamls(shared.model_name) - shared.settings.update(model_settings) + model_settings = get_model_metadata(shared.model_name) + shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) update_model_parameters(model_settings, initial=True) if shared.settings['mode'] != 'instruct': @@ -195,7 +197,7 @@ class Handler(BaseHTTPRequestHandler): super().end_headers() -def _run_server(port: int, share: bool = False): +def _run_server(port: int, share: bool = False, tunnel_id=str): address = '0.0.0.0' if shared.args.listen else '127.0.0.1' server = ThreadingHTTPServer((address, port), Handler) @@ -205,7 +207,7 @@ def _run_server(port: int, share: bool = False): if share: try: - try_start_cloudflared(port, max_attempts=3, on_start=on_start) + try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start) except Exception: pass else: @@ -215,5 +217,5 @@ def _run_server(port: int, share: bool = False): server.serve_forever() -def start_server(port: int, share: bool = False): - Thread(target=_run_server, args=[port, share], daemon=True).start() +def start_server(port: int, share: bool = False, tunnel_id=str): + Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start() diff --git a/extensions/api/requirements.txt b/extensions/api/requirements.txt index 14e29d3..e4f26c3 100644 --- a/extensions/api/requirements.txt +++ b/extensions/api/requirements.txt @@ -1,2 +1,2 @@ -flask_cloudflared==0.0.12 +flask_cloudflared==0.0.14 websockets==11.0.2 \ No newline at end of file diff --git a/extensions/api/script.py b/extensions/api/script.py index 5d1b1a6..12fd9ca 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -1,8 +1,13 @@ +import time + import extensions.api.blocking_api as blocking_api import extensions.api.streaming_api as streaming_api from modules import shared def setup(): - blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api) - streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api) + blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id) + if shared.args.public_api: + time.sleep(5) + + streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py index 88359e3..9175eeb 100644 --- a/extensions/api/streaming_api.py +++ b/extensions/api/streaming_api.py @@ -2,12 +2,15 @@ import asyncio import json from threading import Thread -from websockets.server import serve - -from extensions.api.util import build_parameters, try_start_cloudflared, with_api_lock +from extensions.api.util import ( + build_parameters, + try_start_cloudflared, + with_api_lock +) from modules import shared from modules.chat import generate_chat_reply from modules.text_generation import generate_reply +from websockets.server import serve PATH = '/api/v1/stream' @@ -99,7 +102,7 @@ async def _run(host: str, port: int): await asyncio.Future() # run forever -def _run_server(port: int, share: bool = False): +def _run_server(port: int, share: bool = False, tunnel_id=str): address = '0.0.0.0' if shared.args.listen else '127.0.0.1' def on_start(public_url: str): @@ -108,7 +111,7 @@ def _run_server(port: int, share: bool = False): if share: try: - try_start_cloudflared(port, max_attempts=3, on_start=on_start) + try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start) except Exception as e: print(e) else: @@ -117,5 +120,5 @@ def _run_server(port: int, share: bool = False): asyncio.run(_run(host=address, port=port)) -def start_server(port: int, share: bool = False): - Thread(target=_run_server, args=[port, share], daemon=True).start() +def start_server(port: int, share: bool = False, tunnel_id=str): + Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start() diff --git a/extensions/api/util.py b/extensions/api/util.py index a89365c..e4f7738 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -10,7 +10,6 @@ from modules import shared from modules.chat import load_character_memoized from modules.presets import load_preset_memoized - # We use a thread local to store the asyncio lock, so that each thread # has its own lock. This isn't strictly necessary, but it makes it # such that if we can support multiple worker threads in the future, @@ -22,6 +21,8 @@ def build_parameters(body, chat=False): generate_params = { 'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))), + 'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)), + 'max_tokens_second': int(body.get('max_tokens_second', 0)), 'do_sample': bool(body.get('do_sample', True)), 'temperature': float(body.get('temperature', 0.5)), 'top_p': float(body.get('top_p', 1)), @@ -43,9 +44,12 @@ def build_parameters(body, chat=False): 'mirostat_mode': int(body.get('mirostat_mode', 0)), 'mirostat_tau': float(body.get('mirostat_tau', 5)), 'mirostat_eta': float(body.get('mirostat_eta', 0.1)), + 'guidance_scale': float(body.get('guidance_scale', 1)), + 'negative_prompt': str(body.get('negative_prompt', '')), 'seed': int(body.get('seed', -1)), 'add_bos_token': bool(body.get('add_bos_token', True)), 'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))), + 'custom_token_bans': str(body.get('custom_token_bans', '')), 'ban_eos_token': bool(body.get('ban_eos_token', False)), 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), 'custom_stopping_strings': '', # leave this blank @@ -59,34 +63,37 @@ def build_parameters(body, chat=False): if chat: character = body.get('character') - instruction_template = body.get('instruction_template') - name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), shared.settings['name2'], instruct=False) + instruction_template = body.get('instruction_template', shared.settings['instruction_template']) + if str(instruction_template) == "None": + instruction_template = "Vicuna-v1.1" + if str(character) == "None": + character = "Assistant" + + name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), '', instruct=False) name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True) generate_params.update({ - 'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])), - 'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])), 'mode': str(body.get('mode', 'chat')), - 'name1': name1, - 'name2': name2, - 'context': context, - 'greeting': greeting, - 'name1_instruct': name1_instruct, - 'name2_instruct': name2_instruct, - 'context_instruct': context_instruct, - 'turn_template': turn_template, - 'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])), + 'name1': str(body.get('name1', name1)), + 'name2': str(body.get('name2', name2)), + 'context': str(body.get('context', context)), + 'greeting': str(body.get('greeting', greeting)), + 'name1_instruct': str(body.get('name1_instruct', name1_instruct)), + 'name2_instruct': str(body.get('name2_instruct', name2_instruct)), + 'context_instruct': str(body.get('context_instruct', context_instruct)), + 'turn_template': str(body.get('turn_template', turn_template)), + 'chat-instruct_command': str(body.get('chat_instruct_command', body.get('chat-instruct_command', shared.settings['chat-instruct_command']))), 'history': body.get('history', {'internal': [], 'visible': []}) }) return generate_params -def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): +def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): Thread(target=_start_cloudflared, args=[ - port, max_attempts, on_start], daemon=True).start() + port, tunnel_id, max_attempts, on_start], daemon=True).start() -def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): +def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): try: from flask_cloudflared import _run_cloudflared except ImportError: @@ -96,7 +103,10 @@ def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Call for _ in range(max_attempts): try: - public_url = _run_cloudflared(port, port + 1) + if tunnel_id is not None: + public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id) + else: + public_url = _run_cloudflared(port, port + 1) if on_start: on_start(public_url) diff --git a/extensions/elevenlabs_tts/requirements.txt b/extensions/elevenlabs_tts/requirements.txt index 2cfc196..c3c0cc7 100644 --- a/extensions/elevenlabs_tts/requirements.txt +++ b/extensions/elevenlabs_tts/requirements.txt @@ -1 +1 @@ -elevenlabs==0.2.* +elevenlabs==0.2.24 diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py index c5e6174..68ae16b 100644 --- a/extensions/elevenlabs_tts/script.py +++ b/extensions/elevenlabs_tts/script.py @@ -1,10 +1,12 @@ +import html import re from pathlib import Path import elevenlabs import gradio as gr -from modules import chat, shared +from modules import chat, shared, ui_chat +from modules.logging_colors import logger from modules.utils import gradio params = { @@ -13,10 +15,12 @@ params = { 'selected_voice': 'None', 'autoplay': False, 'show_text': True, + 'model': 'eleven_monolingual_v1', } voices = None wav_idx = 0 +LANG_MODELS = ['eleven_monolingual_v1', 'eleven_multilingual_v1'] def update_api_key(key): @@ -108,7 +112,7 @@ def output_modifier(string): output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx)) print(f'Outputting audio to {str(output_file)}') try: - audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model="eleven_monolingual_v1") + audio = elevenlabs.generate(text=html.unescape(string), voice=params['selected_voice'], model=params['model']) elevenlabs.save(audio, str(output_file)) autoplay = 'autoplay' if params['autoplay'] else '' @@ -132,7 +136,12 @@ def ui(): global voices if not voices: voices = refresh_voices() - params['selected_voice'] = voices[0] + selected = params['selected_voice'] + if selected == 'None': + params['selected_voice'] = voices[0] + elif selected not in voices: + logger.error(f'Selected voice {selected} not available, switching to {voices[0]}') + params['selected_voice'] = voices[0] # Gradio elements with gr.Row(): @@ -145,36 +154,43 @@ def ui(): refresh = gr.Button(value='Refresh') with gr.Row(): - api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key') + if params['api_key']: + api_key = gr.Textbox(value=params['api_key'], label='API Key') + update_api_key(params['api_key']) + else: + api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key') + + with gr.Row(): + model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model') with gr.Row(): convert = gr.Button('Permanently replace audios with the message texts') convert_cancel = gr.Button('Cancel', visible=False) convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False) - if shared.is_chat(): - # Convert history with confirmation - convert_arr = [convert_confirm, convert, convert_cancel] - convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) - convert_confirm.click( - lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( - remove_tts_from_history, gradio('history'), gradio('history')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) + # Convert history with confirmation + convert_arr = [convert_confirm, convert, convert_cancel] + convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) + convert_confirm.click( + lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( + remove_tts_from_history, gradio('history'), gradio('history')).then( + chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( + chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display')) - convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) + convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) - # Toggle message text in history - show_text.change( - lambda x: params.update({"show_text": x}), show_text, None).then( - toggle_text_in_history, gradio('history'), gradio('history')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) + # Toggle message text in history + show_text.change( + lambda x: params.update({"show_text": x}), show_text, None).then( + toggle_text_in_history, gradio('history'), gradio('history')).then( + chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( + chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display')) # Event functions to update the parameters in the backend activate.change(lambda x: params.update({'activate': x}), activate, None) voice.change(lambda x: params.update({'selected_voice': x}), voice, None) api_key.change(update_api_key, api_key, None) + model.change(lambda x: params.update({'model': x}), model, None) # connect.click(check_valid_api, [], connection_status) refresh.click(refresh_voices_dd, [], voice) # Event functions to update the parameters in the backend diff --git a/extensions/example/script.py b/extensions/example/script.py new file mode 100644 index 0000000..44f0cb3 --- /dev/null +++ b/extensions/example/script.py @@ -0,0 +1,139 @@ +""" +An example of extension. It does nothing, but you can add transformations +before the return statements to customize the webui behavior. + +Starting from history_modifier and ending in output_modifier, the +functions are declared in the same order that they are called at +generation time. +""" + +import gradio as gr +import torch +from transformers import LogitsProcessor + +from modules import chat, shared +from modules.text_generation import ( + decode, + encode, + generate_reply, +) + +params = { + "display_name": "Example Extension", + "is_tab": False, +} + +class MyLogits(LogitsProcessor): + """ + Manipulates the probabilities for the next token before it gets sampled. + Used in the logits_processor_modifier function below. + """ + def __init__(self): + pass + + def __call__(self, input_ids, scores): + # probs = torch.softmax(scores, dim=-1, dtype=torch.float) + # probs[0] /= probs[0].sum() + # scores = torch.log(probs / (1 - probs)) + return scores + +def history_modifier(history): + """ + Modifies the chat history. + Only used in chat mode. + """ + return history + +def state_modifier(state): + """ + Modifies the state variable, which is a dictionary containing the input + values in the UI like sliders and checkboxes. + """ + return state + +def chat_input_modifier(text, visible_text, state): + """ + Modifies the user input string in chat mode (visible_text). + You can also modify the internal representation of the user + input (text) to change how it will appear in the prompt. + """ + return text, visible_text + +def input_modifier(string, state, is_chat=False): + """ + In default/notebook modes, modifies the whole prompt. + + In chat mode, it is the same as chat_input_modifier but only applied + to "text", here called "string", and not to "visible_text". + """ + return string + +def bot_prefix_modifier(string, state): + """ + Modifies the prefix for the next bot reply in chat mode. + By default, the prefix will be something like "Bot Name:". + """ + return string + +def tokenizer_modifier(state, prompt, input_ids, input_embeds): + """ + Modifies the input ids and embeds. + Used by the multimodal extension to put image embeddings in the prompt. + Only used by loaders that use the transformers library for sampling. + """ + return prompt, input_ids, input_embeds + +def logits_processor_modifier(processor_list, input_ids): + """ + Adds logits processors to the list, allowing you to access and modify + the next token probabilities. + Only used by loaders that use the transformers library for sampling. + """ + processor_list.append(MyLogits()) + return processor_list + +def output_modifier(string, state, is_chat=False): + """ + Modifies the LLM output before it gets presented. + + In chat mode, the modified version goes into history['visible'], + and the original version goes into history['internal']. + """ + return string + +def custom_generate_chat_prompt(user_input, state, **kwargs): + """ + Replaces the function that generates the prompt from the chat history. + Only used in chat mode. + """ + result = chat.generate_chat_prompt(user_input, state, **kwargs) + return result + +def custom_css(): + """ + Returns a CSS string that gets appended to the CSS for the webui. + """ + return '' + +def custom_js(): + """ + Returns a javascript string that gets appended to the javascript + for the webui. + """ + return '' + +def setup(): + """ + Gets executed only once, when the extension is imported. + """ + pass + +def ui(): + """ + Gets executed when the UI is drawn. Custom gradio elements and + their corresponding event handlers should be defined here. + + To learn about gradio components, check out the docs: + https://gradio.app/docs/ + """ + pass diff --git a/extensions/gallery/script.js b/extensions/gallery/script.js new file mode 100644 index 0000000..4ff23af --- /dev/null +++ b/extensions/gallery/script.js @@ -0,0 +1,33 @@ +let gallery_element = document.getElementById('gallery-extension'); +let chat_mode_element = document.getElementById('chat-mode'); + +let extensions_block = document.getElementById('extensions'); +let extensions_block_size = extensions_block.childNodes.length; +let gallery_only = (extensions_block_size == 5); + +document.querySelector('.header_bar').addEventListener('click', function(event) { + if (event.target.tagName === 'BUTTON') { + const buttonText = event.target.textContent.trim(); + + let chat_visible = (buttonText == 'Chat'); + let default_visible = (buttonText == 'Default'); + let notebook_visible = (buttonText == 'Notebook'); + let chat_mode_visible = (chat_mode_element.offsetHeight > 0 && chat_mode_element.offsetWidth > 0); + + // Only show this extension in the Chat tab + if (chat_visible) { + if (chat_mode_visible) { + gallery_element.style.display = 'block'; + extensions_block.style.display = ''; + } else { + gallery_element.style.display = 'none'; + extensions_block.style.display = 'none'; + } + } else { + gallery_element.style.display = 'none'; + if (gallery_only) { + extensions_block.style.display = 'none'; + } + } + } +}); diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index 993ef27..611a11f 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -82,8 +82,13 @@ def select_character(evt: gr.SelectData): return (evt.value[1]) +def custom_js(): + path_to_js = Path(__file__).parent.resolve() / 'script.js' + return open(path_to_js, 'r').read() + + def ui(): - with gr.Accordion("Character gallery", open=False): + with gr.Accordion("Character gallery", open=False, elem_id='gallery-extension'): update = gr.Button("Refresh") gr.HTML(value="") gallery = gr.Dataset(components=[gr.HTML(visible=False)], diff --git a/extensions/google_translate/script.py b/extensions/google_translate/script.py index 5dfdbcd..784668c 100644 --- a/extensions/google_translate/script.py +++ b/extensions/google_translate/script.py @@ -1,3 +1,5 @@ +import html + import gradio as gr from deep_translator import GoogleTranslator @@ -27,7 +29,8 @@ def output_modifier(string): if not params['activate']: return string - return GoogleTranslator(source='en', target=params['language string']).translate(string) + translated_str = GoogleTranslator(source='en', target=params['language string']).translate(html.unescape(string)) + return html.escape(translated_str) def bot_prefix_modifier(string): diff --git a/extensions/llava/script.py b/extensions/llava/script.py deleted file mode 100644 index 781d584..0000000 --- a/extensions/llava/script.py +++ /dev/null @@ -1,8 +0,0 @@ -import gradio as gr - -from modules.logging_colors import logger - - -def ui(): - gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead") - logger.error("LLaVA extension is deprecated, use \"multimodal\" extension instead") diff --git a/extensions/long_replies/script.py b/extensions/long_replies/script.py new file mode 100644 index 0000000..035e8c9 --- /dev/null +++ b/extensions/long_replies/script.py @@ -0,0 +1,143 @@ +import torch +from modules import chat, shared +from modules.text_generation import ( + decode, + encode, + generate_reply, +) +from transformers import LogitsProcessor +import gradio as gr + +params = { + "display_name": "Long replies", + "is_tab": False, + "min_length": 120, +} + +initial_size = 0 + +class MyLogits(LogitsProcessor): + """ + Manipulates the probabilities for the next token before it gets sampled. + Used in the logits_processor_modifier function below. + """ + def __init__(self): + self.newline_id = shared.tokenizer.encode('\n')[-1] + pass + + def __call__(self, input_ids, scores): + if input_ids.shape[-1] - initial_size < params["min_length"]: + scores[...,self.newline_id] = -1000 + # scores[...,shared.tokenizer.eos_token_id] = -1000 + + # probs = torch.softmax(scores, dim=-1, dtype=torch.float) + # probs[0] /= probs[0].sum() + # scores = torch.log(probs / (1 - probs)) + return scores + +def history_modifier(history): + """ + Modifies the chat history. + Only used in chat mode. + """ + return history + +def state_modifier(state): + """ + Modifies the state variable, which is a dictionary containing the input + values in the UI like sliders and checkboxes. + """ + return state + +def chat_input_modifier(text, visible_text, state): + """ + Modifies the user input string in chat mode (visible_text). + You can also modify the internal representation of the user + input (text) to change how it will appear in the prompt. + """ + return text, visible_text + +def input_modifier(string, state): + """ + In default/notebook modes, modifies the whole prompt. + + In chat mode, it is the same as chat_input_modifier but only applied + to "text", here called "string", and not to "visible_text". + """ + return string + +def bot_prefix_modifier(string, state): + """ + Modifies the prefix for the next bot reply in chat mode. + By default, the prefix will be something like "Bot Name:". + """ + return string + +def tokenizer_modifier(state, prompt, input_ids, input_embeds): + """ + Modifies the input ids and embeds. + Used by the multimodal extension to put image embeddings in the prompt. + Only used by loaders that use the transformers library for sampling. + """ + + global initial_size + initial_size = input_ids.shape[-1] + + return prompt, input_ids, input_embeds + +def logits_processor_modifier(processor_list, input_ids): + """ + Adds logits processors to the list, allowing you to access and modify + the next token probabilities. + Only used by loaders that use the transformers library for sampling. + """ + processor_list.append(MyLogits()) + return processor_list + +def output_modifier(string, state): + """ + Modifies the LLM output before it gets presented. + + In chat mode, the modified version goes into history['visible'], + and the original version goes into history['internal']. + """ + return string + +def custom_generate_chat_prompt(user_input, state, **kwargs): + """ + Replaces the function that generates the prompt from the chat history. + Only used in chat mode. + """ + result = chat.generate_chat_prompt(user_input, state, **kwargs) + return result + +def custom_css(): + """ + Returns a CSS string that gets appended to the CSS for the webui. + """ + return '' + +def custom_js(): + """ + Returns a javascript string that gets appended to the javascript + for the webui. + """ + return '' + +def setup(): + """ + Gets executed only once, when the extension is imported. + """ + pass + +def ui(): + """ + Gets executed when the UI is drawn. Custom gradio elements and + their corresponding event handlers should be defined here. + + To learn about gradio components, check out the docs: + https://gradio.app/docs/ + """ + + min_length = gr.Slider(0, 800, step=10, value=params['min_length'], label='Minimum reply length') + min_length.change(lambda x: params.update({'min_length': x}), min_length, None) diff --git a/extensions/multimodal/README.md b/extensions/multimodal/README.md index 0f515ae..5068103 100644 --- a/extensions/multimodal/README.md +++ b/extensions/multimodal/README.md @@ -11,10 +11,10 @@ https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples: ``` -python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat -python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat -python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat -python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat +python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b +python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b +python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b +python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b ``` There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`: @@ -38,6 +38,8 @@ As of now, the following multimodal pipelines are supported: |[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in| |[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)| |[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)| +|[InstructBLIP 7B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-7b`|[Vicuna v1.1 7B](https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)| +|[InstructBLIP 13B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-13b`|[Vicuna v1.1 13B](https://huggingface.co/TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)| Some pipelines could support different LLMs but do note that while it might work, it isn't a supported configuration. diff --git a/extensions/multimodal/pipelines/llava/llava.py b/extensions/multimodal/pipelines/llava/llava.py index eca2be5..306ab22 100644 --- a/extensions/multimodal/pipelines/llava/llava.py +++ b/extensions/multimodal/pipelines/llava/llava.py @@ -56,10 +56,13 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline): @staticmethod def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor: - if hasattr(shared.model.model, 'embed_tokens'): - func = shared.model.model.embed_tokens + for attr in ['', 'model', 'model.model', 'model.model.model']: + tmp = getattr(shared.model, attr, None) if attr != '' else shared.model + if tmp is not None and hasattr(tmp, 'embed_tokens'): + func = tmp.embed_tokens + break else: - func = shared.model.model.model.embed_tokens # AutoGPTQ case + raise ValueError('The embed_tokens method has not been found for this loader.') return func(input_ids).to(shared.model.device, dtype=shared.model.dtype) diff --git a/extensions/multimodal/script.py b/extensions/multimodal/script.py index b3f654e..8bc2631 100644 --- a/extensions/multimodal/script.py +++ b/extensions/multimodal/script.py @@ -35,6 +35,15 @@ input_hijack = { multimodal_embedder: MultimodalEmbedder = None +def chat_input_modifier(text, visible_text, state): + global input_hijack + if input_hijack['state']: + input_hijack['state'] = False + return input_hijack['value'](text, visible_text) + else: + return text, visible_text + + def add_chat_picture(picture, text, visible_text): # resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable) max_hw, min_hw = max(picture.size), min(picture.size) diff --git a/extensions/ngrok/script.py b/extensions/ngrok/script.py index 782deea..46f39bd 100644 --- a/extensions/ngrok/script.py +++ b/extensions/ngrok/script.py @@ -1,8 +1,8 @@ # Adds ngrok ingress, to use add `--extension ngrok` to the command line options # -# Parameters can be customized in settings.json of webui, e.g.: +# Parameters can be customized in settings.json of webui, e.g.: # {"ngrok": {"basic_auth":"user:password"} } -# or +# or # {"ngrok": {"oauth_provider":"google", "oauth_allow_emails":["asdf@asdf.com"]} } # # See this example for full list of options: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py @@ -22,6 +22,7 @@ options = { 'session_metadata': 'text-generation-webui', } + def ui(): settings = shared.settings.get("ngrok") if settings: @@ -33,4 +34,3 @@ def ui(): logging.info(f"Ingress established at: {tunnel.url()}") except ModuleNotFoundError: logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`") - diff --git a/extensions/openai/README.md b/extensions/openai/README.md index 0f775bb..c2054a5 100644 --- a/extensions/openai/README.md +++ b/extensions/openai/README.md @@ -1,17 +1,16 @@ # An OpenedAI API (openai like) This extension creates an API that works kind of like openai (ie. api.openai.com). -It's incomplete so far but perhaps is functional enough for you. -## Setup & installation +## Setup & installation -Optional (for flask_cloudflared, embeddings): +Install the requirements: ``` pip3 install -r requirements.txt ``` -It listens on tcp port 5001 by default. You can use the OPENEDAI_PORT environment variable to change this. +It listens on `tcp port 5001` by default. You can use the `OPENEDAI_PORT` environment variable to change this. Make sure you enable it in server launch parameters, it should include: @@ -19,15 +18,44 @@ Make sure you enable it in server launch parameters, it should include: --extensions openai ``` -You can also use the ``--listen`` argument to make the server available on the networ, and/or the ```--share``` argument to enable a public Cloudflare endpoint. +You can also use the `--listen` argument to make the server available on the networ, and/or the `--share` argument to enable a public Cloudflare endpoint. -To enable the basic image generation support (txt2img) set the environment variable SD_WEBUI_URL to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)). +To enable the basic image generation support (txt2img) set the environment variable `SD_WEBUI_URL` to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)). For example: + ``` SD_WEBUI_URL=http://127.0.0.1:7861 ``` +## Quick start + +1. Install the requirements.txt (pip) +2. Enable the `openeai` module (--extensions openai), restart the server. +3. Configure the openai client + +Most openai application can be configured to connect the API if you set the following environment variables: + +```shell +# Sample .env file: +OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111 +OPENAI_API_BASE=http://0.0.0.0:5001/v1 +``` + +If needed, replace 0.0.0.0 with the IP/port of your server. + + +## Settings + +To adjust your default settings, you can add the following to your `settings.yaml` file. + +``` +openai-port: 5002 +openai-embedding_device: cuda +openai-sd_webui_url: http://127.0.0.1:7861 +openai-debug: 1 +``` + ### Models This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat. @@ -36,9 +64,9 @@ For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](h For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama. -For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following - within the limits of the model. Be sure that the proper instruction template is detected and loaded or the results will not be good. +For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following. Be sure that the proper instruction template is detected and loaded or the results will not be good. -For the proper instruction format to be detected you need to have a matching model entry in your ```models/config.yaml``` file. Be sure to keep this file up to date. +For the proper instruction format to be detected you need to have a matching model entry in your `models/config.yaml` file. Be sure to keep this file up to date. A matching instruction template file in the characters/instruction-following/ folder will loaded and applied to format messages correctly for the model - this is critical for good results. For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry: @@ -49,7 +77,7 @@ For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 instruction_template: 'Vicuna-v1.1' ``` -This refers to ```characters/instruction-following/Vicuna-v1.1.yaml```, which looks like this: +This refers to `characters/instruction-following/Vicuna-v1.1.yaml`, which looks like this: ``` user: "USER:" @@ -61,63 +89,66 @@ context: "A chat between a curious user and an artificial intelligence assistant For most common models this is already setup, but if you are using a new or uncommon model you may need add a matching entry to the models/config.yaml and possibly create your own instruction-following template and for best results. If you see this in your logs, it probably means that the correct format could not be loaded: + ``` Warning: Loaded default instruction-following template for model. ``` ### Embeddings (alpha) -Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future. +Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future. -| model name | dimensions | input max tokens | speed | size | Avg. performance | -| --- | --- | --- | --- | --- | --- | -| text-embedding-ada-002 | 1536 | 8192| - | - | - | -| text-davinci-002 | 768 | 2046 | - | - | - | -| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 | -| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 | +| model name | dimensions | input max tokens | speed | size | Avg. performance | +| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- | +| text-embedding-ada-002 | 1536 | 8192 | - | - | - | +| text-davinci-002 | 768 | 2046 | - | - | - | +| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 | +| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 | -In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable OPENEDAI_EMBEDDING_MODEL, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2". +In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2". Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable. ### Client Application Setup - Almost everything you use it with will require you to set a dummy OpenAI API key environment variable. -With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so: +With the [official python openai client](https://github.com/openai/openai-python), set the `OPENAI_API_BASE` environment variables: -``` +```shell +# Sample .env file: OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111 -OPENAI_API_BASE=http://127.0.0.1:5001/v1 +OPENAI_API_BASE=http://0.0.0.0:5001/v1 ``` -If needed, replace 127.0.0.1 with the IP/port of your server. +If needed, replace 0.0.0.0 with the IP/port of your server. -If using .env files to save the OPENAI_API_BASE and OPENAI_API_KEY variables, you can ensure compatibility by loading the .env file before loading the openai module, like so in python: +If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported: -``` +```python from dotenv import load_dotenv -load_dotenv() +load_dotenv() # make sure the environment variables are set before import import openai ``` With the [official Node.js openai client](https://github.com/openai/openai-node) it is slightly more more complex because the environment variables are not used by default, so small source code changes may be required to use the environment variables, like so: -``` -const openai = OpenAI(Configuration({ - apiKey: process.env.OPENAI_API_KEY, - basePath: process.env.OPENAI_API_BASE, -})); +```js +const openai = OpenAI( + Configuration({ + apiKey: process.env.OPENAI_API_KEY, + basePath: process.env.OPENAI_API_BASE + }) +); ``` For apps made with the [chatgpt-api Node.js client library](https://github.com/transitive-bullshit/chatgpt-api): -``` +```js const api = new ChatGPTAPI({ apiKey: process.env.OPENAI_API_KEY, - apiBaseUrl: process.env.OPENAI_API_BASE, -}) + apiBaseUrl: process.env.OPENAI_API_BASE +}); ``` ## API Documentation & Examples @@ -127,106 +158,99 @@ The OpenAI API is well documented, you can view the documentation here: https:// Examples of how to use the Completions API in Python can be found here: https://platform.openai.com/examples Not all of them will work with all models unfortunately, See the notes on Models for how to get the best results. -Here is a simple python example of how you can use the Edit endpoint as a translator. +Here is a simple python example. ```python +import os +os.environ['OPENAI_API_KEY']="sk-111111111111111111111111111111111111111111111111" +os.environ['OPENAI_API_BASE']="http://0.0.0.0:5001/v1" import openai -response = openai.Edit.create( + +response = openai.ChatCompletion.create( model="x", - instruction="Translate this into French", - input="Our mission is to ensure that artificial general intelligence benefits all of humanity.", + messages = [{ 'role': 'system', 'content': "Answer in a consistent style." }, + {'role': 'user', 'content': "Teach me about patience."}, + {'role': 'assistant', 'content': "The river that carves the deepest valley flows from a modest spring; the grandest symphony originates from a single note; the most intricate tapestry begins with a solitary thread."}, + {'role': 'user', 'content': "Teach me about the ocean."}, + ] ) -print(response['choices'][0]['text']) -# Sample Output: -# Notre mission est de garantir que l'intelligence artificielle généralisée profite à tous les membres de l'humanité. +text = response['choices'][0]['message']['content'] +print(text) ``` - - ## Compatibility & not so compatibility -| API endpoint | tested with | notes | -| --- | --- | --- | -| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options | -| /v1/models/{id} | openai.Model.get() | returns whatever you ask for, model does nothing yet anyways | -| /v1/text_completion | openai.Completion.create() | the most tested, only supports single string input so far, variable quality based on the model | -| /v1/chat/completions | openai.ChatCompletion.create() | Quality depends a lot on the model | -| /v1/edits | openai.Edit.create() | Works the best of all, perfect for instruction following models | -| /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. | -| /v1/embeddings | openai.Embedding.create() | Using Sentence Transformer, dimensions are different and may never be directly comparable to openai embeddings. | -| /v1/moderations | openai.Moderation.create() | does nothing. successfully. | -| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) | -| /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint | -| /v1/engines/*/generate | openai engines.generate | Legacy endpoint | -| /v1/engines | openai engines.list | Legacy Lists models | -| /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api | -| /v1/images/edits | openai.Image.create_edit() | not yet supported | -| /v1/images/variations | openai.Image.create_variation() | not yet supported | -| /v1/audio/\* | openai.Audio.\* | not yet supported | -| /v1/files\* | openai.Files.\* | not yet supported | -| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported | -| /v1/search | openai.search, engines.search | not yet supported | +| API endpoint | tested with | notes | +| ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- | +| /v1/chat/completions | openai.ChatCompletion.create() | Use it with instruction following models | +| /v1/embeddings | openai.Embedding.create() | Using SentenceTransformer embeddings | +| /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. | +| /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings | +| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options | +| /v1/models/{id} | openai.Model.get() | returns whatever you ask for | +| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models | +| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model | +| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) | +| /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint | +| /v1/engines/\*/generate | openai engines.generate | Legacy endpoint | +| /v1/engines | openai engines.list | Legacy Lists models | +| /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api or command line | +| /v1/images/edits | openai.Image.create_edit() | not yet supported | +| /v1/images/variations | openai.Image.create_variation() | not yet supported | +| /v1/audio/\* | openai.Audio.\* | supported | +| /v1/files\* | openai.Files.\* | not yet supported | +| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported | +| /v1/search | openai.search, engines.search | not yet supported | -The model name setting is ignored in completions, but you may need to adjust the maximum token length to fit the model (ie. set to <2048 tokens instead of 4096, 8k, etc). To mitigate some of this, the max_tokens value is halved until it is less than truncation_length for the model (typically 2k). +Because of the differences in OpenAI model context sizes (2k, 4k, 8k, 16k, etc,) you may need to adjust the max_tokens to fit into the context of the model you choose. Streaming, temperature, top_p, max_tokens, stop, should all work as expected, but not all parameters are mapped correctly. Some hacky mappings: -| OpenAI | text-generation-webui | note | -| --- | --- | --- | -| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way | -| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 | -| best_of | top_k | default is 1 | -| stop | custom_stopping_strings | this is also stuffed with ['\n###', "\n{user prompt}", "{user prompt}" ] for good measure. | -| n | 1 | variations are not supported yet. | -| 1 | num_beams | hardcoded to 1 | -| 1.0 | typical_p | hardcoded to 1.0 | -| max_tokens | max_new_tokens | For Text Completions max_tokens is set smaller than the truncation_length minus the prompt length. This can cause no input to be generated if the prompt is too large. For ChatCompletions, the older chat messages may be dropped to fit the max_new_tokens requested | -| logprobs | - | not supported yet | -| logit_bias | - | not supported yet | -| messages.name | - | not supported yet | -| user | - | not supported yet | -| functions/function_call | - | function calls are not supported yet | - -defaults are mostly from openai, so are different. I use the openai defaults where I can and try to scale them to the webui defaults with the same intent. +| OpenAI | text-generation-webui | note | +| ----------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| model | - | Ignored, the model is not changed | +| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way | +| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 | +| best_of | top_k | default is 1 (top_k is 20 for chat, which doesn't support best_of) | +| n | 1 | variations are not supported yet. | +| 1 | num_beams | hardcoded to 1 | +| 1.0 | typical_p | hardcoded to 1.0 | +| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' | +| messages.name | - | not supported yet | +| suffix | - | not supported yet | +| user | - | not supported yet | +| functions/function_call | - | function calls are not supported yet | ### Applications -Almost everything needs the OPENAI_API_KEY environment variable set, for example: -``` -OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111 -``` -Some apps are picky about key format, but 'dummy' or 'sk-dummy' also work in most cases. -Most application will work if you also set: -``` -OPENAI_API_BASE=http://127.0.0.1:5001/v1 -``` -but there are some exceptions. +Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variable set, but there are some exceptions. -| Compatibility | Application/Library | url | notes / setting | -| --- | --- | --- | --- | -| ✅❌ | openai-python (v0.25+) | https://github.com/openai/openai-python | only the endpoints from above are working. OPENAI_API_BASE=http://127.0.0.1:5001/v1 | -| ✅❌ | openai-node | https://github.com/openai/openai-node | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) | -| ✅❌ | chatgpt-api | https://github.com/transitive-bullshit/chatgpt-api | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) | -| ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI | -| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 | -| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | -| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | -| ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5001 in the config file, or environment variables | -| ✅ | vscode-openai | https://marketplace.visualstudio.com/items?itemName=AndrewButson.vscode-openai | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | -| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. | -| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context | -| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | +| Compatibility | Application/Library | Website | Notes | +| ------------- | ---------------------- | ------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| ✅❌ | openai-python (v0.25+) | https://github.com/openai/openai-python | only the endpoints from above are working. OPENAI_API_BASE=http://127.0.0.1:5001/v1 | +| ✅❌ | openai-node | https://github.com/openai/openai-node | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) | +| ✅❌ | chatgpt-api | https://github.com/transitive-bullshit/chatgpt-api | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) | +| ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI, Images also work | +| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 | +| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | +| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | +| ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5001 in the config file, or environment variables | +| ✅ | vscode-openai | https://marketplace.visualstudio.com/items?itemName=AndrewButson.vscode-openai | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | +| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. | +| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context | +| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | +| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported | ## Future plans -* better error handling -* model changing, esp. something for swapping loras or embedding models -* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard) -* do something about rate limiting or locking requests for completions, most systems will only be able handle a single request at a time before OOM + +- better error handling +- model changing, esp. something for swapping loras or embedding models +- consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard) ## Bugs? Feedback? Comments? Pull requests? -To enable debugging and get copious output you can set the OPENEDAI_DEBUG=1 environment variable. +To enable debugging and get copious output you can set the `OPENEDAI_DEBUG=1` environment variable. Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible. diff --git a/extensions/openai/cache_embedding_model.py b/extensions/openai/cache_embedding_model.py index 44ac1dc..2dd6cb2 100644 --- a/extensions/openai/cache_embedding_model.py +++ b/extensions/openai/cache_embedding_model.py @@ -3,6 +3,10 @@ # Dockerfile: # ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional # RUN python3 cache_embedded_model.py -import os, sentence_transformers -st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2" +import os + +import sentence_transformers +from extensions.openai.script import params + +st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2')) model = sentence_transformers.SentenceTransformer(st_model) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py new file mode 100644 index 0000000..40d96c1 --- /dev/null +++ b/extensions/openai/completions.py @@ -0,0 +1,637 @@ +import time + +import tiktoken +import torch +import torch.nn.functional as F +import yaml +from extensions.openai.defaults import clamp, default, get_default_req_params +from extensions.openai.errors import InvalidRequestError +from extensions.openai.utils import debug_msg, end_line +from modules import shared +from modules.text_generation import decode, encode, generate_reply +from transformers import LogitsProcessor, LogitsProcessorList + + +# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic +class LogitsBiasProcessor(LogitsProcessor): + def __init__(self, logit_bias={}): + self.logit_bias = logit_bias + if self.logit_bias: + self.keys = list([int(key) for key in self.logit_bias.keys()]) + values = [self.logit_bias[str(key)] for key in self.keys] + self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device) + debug_msg(f"{self})") + + def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor: + if self.logit_bias: + debug_msg(logits[0, self.keys], " + ", self.values) + logits[0, self.keys] += self.values + debug_msg(" --> ", logits[0, self.keys]) + debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0]))) + return logits + + def __repr__(self): + return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>" + + +class LogprobProcessor(LogitsProcessor): + def __init__(self, logprobs=None): + self.logprobs = logprobs + self.token_alternatives = {} + + def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor: + if self.logprobs is not None: # 0-5 + log_e_probabilities = F.log_softmax(logits, dim=1) + top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1) + top_tokens = [decode(tok) for tok in top_indices[0]] + top_probs = [float(x) for x in top_values[0]] + self.token_alternatives = dict(zip(top_tokens, top_probs)) + debug_msg(repr(self)) + return logits + + def __repr__(self): + return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>" + + +def convert_logprobs_to_tiktoken(model, logprobs): + # more problems than it's worth. + # try: + # encoder = tiktoken.encoding_for_model(model) + # # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall. + # return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()]) + # except KeyError: + # # assume native tokens if we can't find the tokenizer + # return logprobs + + return logprobs + + +def marshal_common_params(body): + # Request Parameters + # Try to use openai defaults or map them to something with the same intent + + req_params = get_default_req_params() + + # Common request parameters + req_params['truncation_length'] = shared.settings['truncation_length'] + req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token']) + req_params['seed'] = shared.settings.get('seed', req_params['seed']) + req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings'] + + # OpenAI API Parameters + # model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this + req_params['requested_model'] = body.get('model', shared.model_name) + + req_params['suffix'] = default(body, 'suffix', req_params['suffix']) + req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0 + req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0) + n = default(body, 'n', 1) + if n != 1: + raise InvalidRequestError(message="Only n = 1 is supported.", param='n') + + if 'stop' in body: # str or array, max len 4 (ignored) + if isinstance(body['stop'], str): + req_params['stopping_strings'] = [body['stop']] # non-standard parameter + elif isinstance(body['stop'], list): + req_params['stopping_strings'] = body['stop'] + + # presence_penalty - ignored + # frequency_penalty - ignored + + # pass through unofficial params + req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty']) + req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty']) + + # user - ignored + + logits_processor = [] + logit_bias = body.get('logit_bias', None) + if logit_bias: # {str: float, ...} + # XXX convert tokens from tiktoken based on requested model + # Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100} + try: + encoder = tiktoken.encoding_for_model(req_params['requested_model']) + new_logit_bias = {} + for logit, bias in logit_bias.items(): + for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]: + if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens + continue + new_logit_bias[str(int(x))] = bias + debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias) + logit_bias = new_logit_bias + except KeyError: + pass # assume native tokens if we can't find the tokenizer + + logits_processor = [LogitsBiasProcessor(logit_bias)] + + logprobs = None # coming to chat eventually + if 'logprobs' in body: + logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5. + req_params['logprob_proc'] = LogprobProcessor(logprobs) + logits_processor.extend([req_params['logprob_proc']]) + else: + logprobs = None + + if logits_processor: # requires logits_processor support + req_params['logits_processor'] = LogitsProcessorList(logits_processor) + + return req_params + + +def messages_to_prompt(body: dict, req_params: dict, max_tokens): + # functions + if body.get('functions', []): # chat only + raise InvalidRequestError(message="functions is not supported.", param='functions') + if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'} + raise InvalidRequestError(message="function_call is not supported.", param='function_call') + + if 'messages' not in body: + raise InvalidRequestError(message="messages is required", param='messages') + + messages = body['messages'] + + role_formats = { + 'user': 'User: {message}\n', + 'assistant': 'Assistant: {message}\n', + 'system': '{message}', + 'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?', + 'prompt': 'Assistant:', + } + + if 'stopping_strings' not in req_params: + req_params['stopping_strings'] = [] + + # Instruct models can be much better + if shared.settings['instruction_template']: + try: + instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r')) + + template = instruct['turn_template'] + system_message_template = "{message}" + system_message_default = instruct.get('context', '') # can be missing + bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token + user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', '')) + bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', '')) + bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ') + + role_formats = { + 'user': user_message_template, + 'assistant': bot_message_template, + 'system': system_message_template, + 'context': system_message_default, + 'prompt': bot_prompt, + } + + if 'Alpaca' in shared.settings['instruction_template']: + req_params['stopping_strings'].extend(['\n###']) + elif instruct['user']: # WizardLM and some others have no user prompt. + req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']]) + + debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}") + + except Exception as e: + req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also + + print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}") + print("Warning: Loaded default instruction-following template for model.") + + else: + req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also + print("Warning: Loaded default instruction-following template for model.") + + system_msgs = [] + chat_msgs = [] + + # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date} + context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else '' + context_msg = end_line(context_msg) + + # Maybe they sent both? This is not documented in the API, but some clients seem to do this. + if 'prompt' in body: + context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg + + for m in messages: + if 'role' not in m: + raise InvalidRequestError(message="messages: missing role", param='messages') + if 'content' not in m: + raise InvalidRequestError(message="messages: missing content", param='messages') + + role = m['role'] + content = m['content'] + # name = m.get('name', None) + # function_call = m.get('function_call', None) # user name or function name with output in content + msg = role_formats[role].format(message=content) + if role == 'system': + system_msgs.extend([msg]) + elif role == 'function': + raise InvalidRequestError(message="role: function is not supported.", param='messages') + else: + chat_msgs.extend([msg]) + + system_msg = '\n'.join(system_msgs) + system_msg = end_line(system_msg) + + prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt'] + + token_count = len(encode(prompt)[0]) + + if token_count >= req_params['truncation_length']: + err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens." + raise InvalidRequestError(message=err_msg, param='messages') + + if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']: + err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}." + print(f"Warning: ${err_msg}") + # raise InvalidRequestError(message=err_msg, params='max_tokens') + + return prompt, token_count + + +def chat_completions(body: dict, is_legacy: bool = False) -> dict: + # Chat Completions + object_type = 'chat.completions' + created_time = int(time.time()) + cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000)) + resp_list = 'data' if is_legacy else 'choices' + + # common params + req_params = marshal_common_params(body) + req_params['stream'] = False + requested_model = req_params.pop('requested_model') + logprob_proc = req_params.pop('logprob_proc', None) + req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k. + + # chat default max_tokens is 'inf', but also flexible + max_tokens = 0 + max_tokens_str = 'length' if is_legacy else 'max_tokens' + if max_tokens_str in body: + max_tokens = default(body, max_tokens_str, req_params['truncation_length']) + req_params['max_new_tokens'] = max_tokens + else: + req_params['max_new_tokens'] = req_params['truncation_length'] + + # format the prompt from messages + prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings'] + + # set real max, avoid deeper errors + if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']: + req_params['max_new_tokens'] = req_params['truncation_length'] - token_count + + stopping_strings = req_params.pop('stopping_strings', []) + + # generate reply ####################################### + debug_msg({'prompt': prompt, 'req_params': req_params}) + generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) + + answer = '' + for a in generator: + answer = a + + # strip extra leading space off new generated content + if answer and answer[0] == ' ': + answer = answer[1:] + + completion_token_count = len(encode(answer)[0]) + stop_reason = "stop" + if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']: + stop_reason = "length" + + resp = { + "id": cmpl_id, + "object": object_type, + "created": created_time, + "model": shared.model_name, # TODO: add Lora info? + resp_list: [{ + "index": 0, + "finish_reason": stop_reason, + "message": {"role": "assistant", "content": answer} + }], + "usage": { + "prompt_tokens": token_count, + "completion_tokens": completion_token_count, + "total_tokens": token_count + completion_token_count + } + } + if logprob_proc: # not official for chat yet + top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives) + resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]} + # else: + # resp[resp_list][0]["logprobs"] = None + + return resp + + +# generator +def stream_chat_completions(body: dict, is_legacy: bool = False): + + # Chat Completions + stream_object_type = 'chat.completions.chunk' + created_time = int(time.time()) + cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000)) + resp_list = 'data' if is_legacy else 'choices' + + # common params + req_params = marshal_common_params(body) + req_params['stream'] = True + requested_model = req_params.pop('requested_model') + logprob_proc = req_params.pop('logprob_proc', None) + req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k. + + # chat default max_tokens is 'inf', but also flexible + max_tokens = 0 + max_tokens_str = 'length' if is_legacy else 'max_tokens' + if max_tokens_str in body: + max_tokens = default(body, max_tokens_str, req_params['truncation_length']) + req_params['max_new_tokens'] = max_tokens + else: + req_params['max_new_tokens'] = req_params['truncation_length'] + + # format the prompt from messages + prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings'] + + # set real max, avoid deeper errors + if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']: + req_params['max_new_tokens'] = req_params['truncation_length'] - token_count + + def chat_streaming_chunk(content): + # begin streaming + chunk = { + "id": cmpl_id, + "object": stream_object_type, + "created": created_time, + "model": shared.model_name, + resp_list: [{ + "index": 0, + "finish_reason": None, + # So yeah... do both methods? delta and messages. + "message": {'role': 'assistant', 'content': content}, + "delta": {'role': 'assistant', 'content': content}, + }], + } + + if logprob_proc: # not official for chat yet + top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives) + chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]} + # else: + # chunk[resp_list][0]["logprobs"] = None + return chunk + + yield chat_streaming_chunk('') + + # generate reply ####################################### + debug_msg({'prompt': prompt, 'req_params': req_params}) + + stopping_strings = req_params.pop('stopping_strings', []) + + generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) + + answer = '' + seen_content = '' + completion_token_count = 0 + + for a in generator: + answer = a + + len_seen = len(seen_content) + new_content = answer[len_seen:] + + if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. + continue + + seen_content = answer + + # strip extra leading space off new generated content + if len_seen == 0 and new_content[0] == ' ': + new_content = new_content[1:] + + chunk = chat_streaming_chunk(new_content) + + yield chunk + + # to get the correct token_count, strip leading space if present + if answer and answer[0] == ' ': + answer = answer[1:] + + completion_token_count = len(encode(answer)[0]) + stop_reason = "stop" + if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']: + stop_reason = "length" + + chunk = chat_streaming_chunk('') + chunk[resp_list][0]['finish_reason'] = stop_reason + chunk['usage'] = { + "prompt_tokens": token_count, + "completion_tokens": completion_token_count, + "total_tokens": token_count + completion_token_count + } + + yield chunk + + +def completions(body: dict, is_legacy: bool = False): + # Legacy + # Text Completions + object_type = 'text_completion' + created_time = int(time.time()) + cmpl_id = "conv-%d" % (int(time.time() * 1000000000)) + resp_list = 'data' if is_legacy else 'choices' + + # ... encoded as a string, array of strings, array of tokens, or array of token arrays. + prompt_str = 'context' if is_legacy else 'prompt' + if prompt_str not in body: + raise InvalidRequestError("Missing required input", param=prompt_str) + + prompt_arg = body[prompt_str] + if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)): + prompt_arg = [prompt_arg] + + # common params + req_params = marshal_common_params(body) + req_params['stream'] = False + max_tokens_str = 'length' if is_legacy else 'max_tokens' + max_tokens = default(body, max_tokens_str, req_params['max_new_tokens']) + req_params['max_new_tokens'] = max_tokens + requested_model = req_params.pop('requested_model') + logprob_proc = req_params.pop('logprob_proc', None) + stopping_strings = req_params.pop('stopping_strings', []) + # req_params['suffix'] = default(body, 'suffix', req_params['suffix']) + req_params['echo'] = default(body, 'echo', req_params['echo']) + req_params['top_k'] = default(body, 'best_of', req_params['top_k']) + + resp_list_data = [] + total_completion_token_count = 0 + total_prompt_token_count = 0 + + for idx, prompt in enumerate(prompt_arg, start=0): + if isinstance(prompt[0], int): + # token lists + if requested_model == shared.model_name: + prompt = decode(prompt)[0] + else: + try: + encoder = tiktoken.encoding_for_model(requested_model) + prompt = encoder.decode(prompt) + except KeyError: + prompt = decode(prompt)[0] + + token_count = len(encode(prompt)[0]) + total_prompt_token_count += token_count + + if token_count + max_tokens > req_params['truncation_length']: + err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})." + # print(f"Warning: ${err_msg}") + raise InvalidRequestError(message=err_msg, param=max_tokens_str) + + # generate reply ####################################### + debug_msg({'prompt': prompt, 'req_params': req_params}) + generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) + answer = '' + + for a in generator: + answer = a + + # strip extra leading space off new generated content + if answer and answer[0] == ' ': + answer = answer[1:] + + completion_token_count = len(encode(answer)[0]) + total_completion_token_count += completion_token_count + stop_reason = "stop" + if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens: + stop_reason = "length" + + respi = { + "index": idx, + "finish_reason": stop_reason, + "text": answer, + "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None, + } + + resp_list_data.extend([respi]) + + resp = { + "id": cmpl_id, + "object": object_type, + "created": created_time, + "model": shared.model_name, # TODO: add Lora info? + resp_list: resp_list_data, + "usage": { + "prompt_tokens": total_prompt_token_count, + "completion_tokens": total_completion_token_count, + "total_tokens": total_prompt_token_count + total_completion_token_count + } + } + + return resp + + +# generator +def stream_completions(body: dict, is_legacy: bool = False): + # Legacy + # Text Completions + # object_type = 'text_completion' + stream_object_type = 'text_completion.chunk' + created_time = int(time.time()) + cmpl_id = "conv-%d" % (int(time.time() * 1000000000)) + resp_list = 'data' if is_legacy else 'choices' + + # ... encoded as a string, array of strings, array of tokens, or array of token arrays. + prompt_str = 'context' if is_legacy else 'prompt' + if prompt_str not in body: + raise InvalidRequestError("Missing required input", param=prompt_str) + + prompt = body[prompt_str] + req_params = marshal_common_params(body) + requested_model = req_params.pop('requested_model') + if isinstance(prompt, list): + if prompt and isinstance(prompt[0], int): + try: + encoder = tiktoken.encoding_for_model(requested_model) + prompt = encoder.decode(prompt) + except KeyError: + prompt = decode(prompt)[0] + else: + raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str) + + # common params + req_params['stream'] = True + max_tokens_str = 'length' if is_legacy else 'max_tokens' + max_tokens = default(body, max_tokens_str, req_params['max_new_tokens']) + req_params['max_new_tokens'] = max_tokens + logprob_proc = req_params.pop('logprob_proc', None) + stopping_strings = req_params.pop('stopping_strings', []) + # req_params['suffix'] = default(body, 'suffix', req_params['suffix']) + req_params['echo'] = default(body, 'echo', req_params['echo']) + req_params['top_k'] = default(body, 'best_of', req_params['top_k']) + + token_count = len(encode(prompt)[0]) + + if token_count + max_tokens > req_params['truncation_length']: + err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})." + # print(f"Warning: ${err_msg}") + raise InvalidRequestError(message=err_msg, param=max_tokens_str) + + def text_streaming_chunk(content): + # begin streaming + chunk = { + "id": cmpl_id, + "object": stream_object_type, + "created": created_time, + "model": shared.model_name, + resp_list: [{ + "index": 0, + "finish_reason": None, + "text": content, + "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None, + }], + } + + return chunk + + yield text_streaming_chunk('') + + # generate reply ####################################### + debug_msg({'prompt': prompt, 'req_params': req_params}) + generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) + + answer = '' + seen_content = '' + completion_token_count = 0 + + for a in generator: + answer = a + + len_seen = len(seen_content) + new_content = answer[len_seen:] + + if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. + continue + + seen_content = answer + + # strip extra leading space off new generated content + if len_seen == 0 and new_content[0] == ' ': + new_content = new_content[1:] + + chunk = text_streaming_chunk(new_content) + + yield chunk + + # to get the correct count, we strip the leading space if present + if answer and answer[0] == ' ': + answer = answer[1:] + + completion_token_count = len(encode(answer)[0]) + stop_reason = "stop" + if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens: + stop_reason = "length" + + chunk = text_streaming_chunk('') + chunk[resp_list][0]["finish_reason"] = stop_reason + chunk["usage"] = { + "prompt_tokens": token_count, + "completion_tokens": completion_token_count, + "total_tokens": token_count + completion_token_count + } + + yield chunk diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py new file mode 100644 index 0000000..7bc5ab2 --- /dev/null +++ b/extensions/openai/defaults.py @@ -0,0 +1,73 @@ +import copy + +# Slightly different defaults for OpenAI's API +# Data type is important, Ex. use 0.0 for a float 0 +default_req_params = { + 'max_new_tokens': 16, # 'Inf' for chat + 'auto_max_new_tokens': False, + 'max_tokens_second': 0, + 'temperature': 1.0, + 'top_p': 1.0, + 'top_k': 1, # choose 20 for chat in absence of another default + 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, + 'encoder_repetition_penalty': 1.0, + 'suffix': None, + 'stream': False, + 'echo': False, + 'seed': -1, + # 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map + 'truncation_length': 2048, # first use shared.settings value + 'add_bos_token': True, + 'do_sample': True, + 'typical_p': 1.0, + 'epsilon_cutoff': 0.0, # In units of 1e-4 + 'eta_cutoff': 0.0, # In units of 1e-4 + 'tfs': 1.0, + 'top_a': 0.0, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0.0, + 'length_penalty': 1.0, + 'early_stopping': False, + 'mirostat_mode': 0, + 'mirostat_tau': 5.0, + 'mirostat_eta': 0.1, + 'guidance_scale': 1, + 'negative_prompt': '', + 'ban_eos_token': False, + 'custom_token_bans': '', + 'skip_special_tokens': True, + 'custom_stopping_strings': '', + # 'logits_processor' - conditionally passed + # 'stopping_strings' - temporarily used + # 'logprobs' - temporarily used + # 'requested_model' - temporarily used +} + + +def get_default_req_params(): + return copy.deepcopy(default_req_params) + + +def default(dic, key, default): + ''' + little helper to get defaults if arg is present but None and should be the same type as default. + ''' + val = dic.get(key, default) + if not isinstance(val, type(default)): + # maybe it's just something like 1 instead of 1.0 + try: + v = type(default)(val) + if type(val)(v) == val: # if it's the same value passed in, it's ok. + return v + except: + pass + + val = default + return val + + +def clamp(value, minvalue, maxvalue): + return max(minvalue, min(value, maxvalue)) diff --git a/extensions/openai/edits.py b/extensions/openai/edits.py new file mode 100644 index 0000000..edf4e6c --- /dev/null +++ b/extensions/openai/edits.py @@ -0,0 +1,101 @@ +import time + +import yaml +from extensions.openai.defaults import get_default_req_params +from extensions.openai.errors import InvalidRequestError +from extensions.openai.utils import debug_msg +from modules import shared +from modules.text_generation import encode, generate_reply + + +def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict: + + created_time = int(time.time() * 1000) + + # Request parameters + req_params = get_default_req_params() + stopping_strings = [] + + # Alpaca is verbose so a good default prompt + default_template = ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + ) + + instruction_template = default_template + + # Use the special instruction/input/response template for anything trained like Alpaca + if shared.settings['instruction_template']: + if 'Alpaca' in shared.settings['instruction_template']: + stopping_strings.extend(['\n###']) + else: + try: + instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r')) + + template = instruct['turn_template'] + template = template\ + .replace('<|user|>', instruct.get('user', ''))\ + .replace('<|bot|>', instruct.get('bot', ''))\ + .replace('<|user-message|>', '{instruction}\n{input}') + + instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ') + if instruct['user']: + stopping_strings.extend(['\n' + instruct['user'], instruct['user']]) + + except Exception as e: + instruction_template = default_template + print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}") + print("Warning: Loaded default instruction-following template (Alpaca) for model.") + else: + stopping_strings.extend(['\n###']) + print("Warning: Loaded default instruction-following template (Alpaca) for model.") + + edit_task = instruction_template.format(instruction=instruction, input=input) + + truncation_length = shared.settings['truncation_length'] + + token_count = len(encode(edit_task)[0]) + max_tokens = truncation_length - token_count + + if max_tokens < 1: + err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens." + raise InvalidRequestError(err_msg, param='input') + + req_params['max_new_tokens'] = max_tokens + req_params['truncation_length'] = truncation_length + req_params['temperature'] = temperature + req_params['top_p'] = top_p + req_params['seed'] = shared.settings.get('seed', req_params['seed']) + req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token']) + req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings'] + + debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count}) + + generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False) + + answer = '' + for a in generator: + answer = a + + # some reply's have an extra leading space to fit the instruction template, just clip it off from the reply. + if edit_task[-1] != '\n' and answer and answer[0] == ' ': + answer = answer[1:] + + completion_token_count = len(encode(answer)[0]) + + resp = { + "object": "edit", + "created": created_time, + "choices": [{ + "text": answer, + "index": 0, + }], + "usage": { + "prompt_tokens": token_count, + "completion_tokens": completion_token_count, + "total_tokens": token_count + completion_token_count + } + } + + return resp diff --git a/extensions/openai/embeddings.py b/extensions/openai/embeddings.py new file mode 100644 index 0000000..96f44d9 --- /dev/null +++ b/extensions/openai/embeddings.py @@ -0,0 +1,80 @@ +import os + +import numpy as np +from extensions.openai.errors import ServiceUnavailableError +from extensions.openai.utils import debug_msg, float_list_to_base64 +from sentence_transformers import SentenceTransformer + +embeddings_params_initialized = False +# using 'lazy loading' to avoid circular import +# so this function will be executed only once +def initialize_embedding_params(): + global embeddings_params_initialized + if not embeddings_params_initialized: + global st_model, embeddings_model, embeddings_device + from extensions.openai.script import params + st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2')) + embeddings_model = None + # OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone + embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu')) + if embeddings_device.lower() == 'auto': + embeddings_device = None + embeddings_params_initialized = True + + +def load_embedding_model(model: str) -> SentenceTransformer: + initialize_embedding_params() + global embeddings_device, embeddings_model + try: + embeddings_model = 'loading...' # flag + # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer + emb_model = SentenceTransformer(model, device=embeddings_device) + # ... emb_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM + print(f"\nLoaded embedding model: {model} on {emb_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {emb_model.max_seq_length}") + except Exception as e: + embeddings_model = None + raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e)) + + return emb_model + + +def get_embeddings_model() -> SentenceTransformer: + initialize_embedding_params() + global embeddings_model, st_model + if st_model and not embeddings_model: + embeddings_model = load_embedding_model(st_model) # lazy load the model + return embeddings_model + + +def get_embeddings_model_name() -> str: + initialize_embedding_params() + global st_model + return st_model + + +def get_embeddings(input: list) -> np.ndarray: + return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device) + + +def embeddings(input: list, encoding_format: str) -> dict: + + embeddings = get_embeddings(input) + + if encoding_format == "base64": + data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)] + else: + data = [{"object": "embedding", "embedding": emb.tolist(), "index": n} for n, emb in enumerate(embeddings)] + + response = { + "object": "list", + "data": data, + "model": st_model, # return the real model + "usage": { + "prompt_tokens": 0, + "total_tokens": 0, + } + } + + debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}") + + return response diff --git a/extensions/openai/errors.py b/extensions/openai/errors.py new file mode 100644 index 0000000..838d1e7 --- /dev/null +++ b/extensions/openai/errors.py @@ -0,0 +1,31 @@ +class OpenAIError(Exception): + def __init__(self, message=None, code=500, internal_message=''): + self.message = message + self.code = code + self.internal_message = internal_message + + def __repr__(self): + return "%s(message=%r, code=%d)" % ( + self.__class__.__name__, + self.message, + self.code, + ) + + +class InvalidRequestError(OpenAIError): + def __init__(self, message, param, code=400, internal_message=''): + super().__init__(message, code, internal_message) + self.param = param + + def __repr__(self): + return "%s(message=%r, code=%d, param=%s)" % ( + self.__class__.__name__, + self.message, + self.code, + self.param, + ) + + +class ServiceUnavailableError(OpenAIError): + def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''): + super().__init__(message, code, internal_message) diff --git a/extensions/openai/images.py b/extensions/openai/images.py new file mode 100644 index 0000000..350ea61 --- /dev/null +++ b/extensions/openai/images.py @@ -0,0 +1,68 @@ +import os +import time + +import requests +from extensions.openai.errors import ServiceUnavailableError + + +def generations(prompt: str, size: str, response_format: str, n: int): + # Stable Diffusion callout wrapper for txt2img + # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E + # the results will be limited and likely poor. SD has hundreds of models and dozens of settings. + # If you want high quality tailored results you should just use the Stable Diffusion API directly. + # it's too general an API to try and shape the result with specific tags like negative prompts + # or "masterpiece", etc. SD configuration is beyond the scope of this API. + # At this point I will not add the edits and variations endpoints (ie. img2img) because they + # require changing the form data handling to accept multipart form data, also to properly support + # url return types will require file management and a web serving files... Perhaps later! + base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512)) + sd_defaults = { + 'sampler_name': 'DPM++ 2M Karras', # vast improvement + 'steps': 30, + } + + width, height = [int(x) for x in size.split('x')] # ignore the restrictions on size + + # to hack on better generation, edit default payload. + payload = { + 'prompt': prompt, # ignore prompt limit of 1000 characters + 'width': width, + 'height': height, + 'batch_size': n, + } + payload.update(sd_defaults) + + scale = min(width, height) / base_model_size + if scale >= 1.2: + # for better performance with the default size (1024), and larger res. + scaler = { + 'width': width // scale, + 'height': height // scale, + 'hr_scale': scale, + 'enable_hr': True, + 'hr_upscaler': 'Latent', + 'denoising_strength': 0.68, + } + payload.update(scaler) + + resp = { + 'created': int(time.time()), + 'data': [] + } + from extensions.openai.script import params + # TODO: support SD_WEBUI_AUTH username:password pair. + sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img" + + response = requests.post(url=sd_url, json=payload) + r = response.json() + if response.status_code != 200 or 'images' not in r: + print(r) + raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None)) + # r['parameters']... + for b64_json in r['images']: + if response_format == 'b64_json': + resp['data'].extend([{'b64_json': b64_json}]) + else: + resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this + + return resp diff --git a/extensions/openai/models.py b/extensions/openai/models.py new file mode 100644 index 0000000..83e550f --- /dev/null +++ b/extensions/openai/models.py @@ -0,0 +1,78 @@ +from extensions.openai.embeddings import get_embeddings_model_name +from extensions.openai.errors import OpenAIError +from modules import shared +from modules.models import load_model as _load_model +from modules.models import unload_model +from modules.models_settings import get_model_metadata, update_model_parameters +from modules.utils import get_available_models + + +def get_current_model_list() -> list: + return [shared.model_name] # The real chat/completions model, maybe "None" + + +def get_pseudo_model_list() -> list: + return [ # these are expected by so much, so include some here as a dummy + 'gpt-3.5-turbo', + 'text-embedding-ada-002', + ] + + +def load_model(model_name: str) -> dict: + resp = { + "id": model_name, + "object": "engine", + "owner": "self", + "ready": True, + } + if model_name not in get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list(): # Real model only + # No args. Maybe it works anyways! + # TODO: hack some heuristics into args for better results + + shared.model_name = model_name + unload_model() + + model_settings = get_model_metadata(shared.model_name) + shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) + update_model_parameters(model_settings, initial=True) + + if shared.settings['mode'] != 'instruct': + shared.settings['instruction_template'] = None + + shared.model, shared.tokenizer = _load_model(shared.model_name) + + if not shared.model: # load failed. + shared.model_name = "None" + raise OpenAIError(f"Model load failed for: {shared.model_name}") + + return resp + + +def list_models(is_legacy: bool = False) -> dict: + # TODO: Lora's? + all_model_list = get_current_model_list() + [get_embeddings_model_name()] + get_pseudo_model_list() + get_available_models() + + models = {} + + if is_legacy: + models = [{"id": id, "object": "engine", "owner": "user", "ready": True} for id in all_model_list] + if not shared.model: + models[0]['ready'] = False + else: + models = [{"id": id, "object": "model", "owned_by": "user", "permission": []} for id in all_model_list] + + resp = { + "object": "list", + "data": models, + } + + return resp + + +def model_info(model_name: str) -> dict: + return { + "id": model_name, + "object": "model", + "owned_by": "user", + "permission": [] + } diff --git a/extensions/openai/moderations.py b/extensions/openai/moderations.py new file mode 100644 index 0000000..1d2d4c1 --- /dev/null +++ b/extensions/openai/moderations.py @@ -0,0 +1,68 @@ +import time + +import numpy as np +from extensions.openai.embeddings import get_embeddings +from numpy.linalg import norm + +moderations_disabled = False # return 0/false +category_embeddings = None +antonym_embeddings = None +categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence"] +flag_threshold = 0.5 + + +def get_category_embeddings() -> dict: + global category_embeddings, categories + if category_embeddings is None: + embeddings = get_embeddings(categories).tolist() + category_embeddings = dict(zip(categories, embeddings)) + + return category_embeddings + + +def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + return np.dot(a, b) / (norm(a) * norm(b)) + + +# seems most openai like with all-mpnet-base-v2 +def mod_score(a: np.ndarray, b: np.ndarray) -> float: + return 2.0 * np.dot(a, b) + + +def moderations(input): + global category_embeddings, categories, flag_threshold, moderations_disabled + results = { + "id": f"modr-{int(time.time()*1e9)}", + "model": "text-moderation-001", + "results": [], + } + + if moderations_disabled: + results['results'] = [{ + 'categories': dict([(C, False) for C in categories]), + 'category_scores': dict([(C, 0.0) for C in categories]), + 'flagged': False, + }] + return results + + category_embeddings = get_category_embeddings() + + # input, string or array + if isinstance(input, str): + input = [input] + + for in_str in input: + for ine in get_embeddings([in_str]): + category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories]) + category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories]) + flagged = any(category_flags.values()) + + results['results'].extend([{ + 'flagged': flagged, + 'categories': category_flags, + 'category_scores': category_scores, + }]) + + print(results) + + return results diff --git a/extensions/openai/requirements.txt b/extensions/openai/requirements.txt index 5193a0a..8c63b5e 100644 --- a/extensions/openai/requirements.txt +++ b/extensions/openai/requirements.txt @@ -1,2 +1,4 @@ +SpeechRecognition==3.10.0 flask_cloudflared==0.0.12 -sentence-transformers \ No newline at end of file +sentence-transformers +tiktoken diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 323d682..b44fc53 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -1,108 +1,40 @@ -import base64 import json import os -import time -import requests -import yaml -import numpy as np +import traceback from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from threading import Thread -from modules.utils import get_available_models -from modules.models import load_model, unload_model -from modules.models_settings import (get_model_settings_from_yamls, - update_model_parameters) +import extensions.openai.completions as OAIcompletions +import extensions.openai.edits as OAIedits +import extensions.openai.embeddings as OAIembeddings +import extensions.openai.images as OAIimages +import extensions.openai.models as OAImodels +import extensions.openai.moderations as OAImoderations +from extensions.openai.defaults import clamp, default, get_default_req_params +from extensions.openai.errors import ( + InvalidRequestError, + OpenAIError, + ServiceUnavailableError +) +from extensions.openai.tokens import token_count, token_decode, token_encode +from extensions.openai.utils import debug_msg from modules import shared -from modules.text_generation import encode, generate_reply + +import cgi +import speech_recognition as sr +from pydub import AudioSegment params = { - 'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001, + # default params + 'port': 5001, + 'embedding_device': 'cpu', + 'embedding_model': 'all-mpnet-base-v2', + + # optional params + 'sd_webui_url': '', + 'debug': 0 } -debug = True if 'OPENEDAI_DEBUG' in os.environ else False - -# Slightly different defaults for OpenAI's API -# Data type is important, Ex. use 0.0 for a float 0 -default_req_params = { - 'max_new_tokens': 200, - 'temperature': 1.0, - 'top_p': 1.0, - 'top_k': 1, - 'repetition_penalty': 1.18, - 'repetition_penalty_range': 0, - 'encoder_repetition_penalty': 1.0, - 'suffix': None, - 'stream': False, - 'echo': False, - 'seed': -1, - # 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map - 'truncation_length': 2048, - 'add_bos_token': True, - 'do_sample': True, - 'typical_p': 1.0, - 'epsilon_cutoff': 0.0, # In units of 1e-4 - 'eta_cutoff': 0.0, # In units of 1e-4 - 'tfs': 1.0, - 'top_a': 0.0, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0.0, - 'length_penalty': 1.0, - 'early_stopping': False, - 'mirostat_mode': 0, - 'mirostat_tau': 5.0, - 'mirostat_eta': 0.1, - 'ban_eos_token': False, - 'skip_special_tokens': True, - 'custom_stopping_strings': '', -} - -# Optional, install the module and download the model to enable -# v1/embeddings -try: - from sentence_transformers import SentenceTransformer -except ImportError: - pass - -st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2" -embedding_model = None - -# little helper to get defaults if arg is present but None and should be the same type as default. -def default(dic, key, default): - val = dic.get(key, default) - if type(val) != type(default): - # maybe it's just something like 1 instead of 1.0 - try: - v = type(default)(val) - if type(val)(v) == val: # if it's the same value passed in, it's ok. - return v - except: - pass - - val = default - return val - - -def clamp(value, minvalue, maxvalue): - return max(minvalue, min(value, maxvalue)) - - -def float_list_to_base64(float_list): - # Convert the list to a float32 array that the OpenAPI client expects - float_array = np.array(float_list, dtype="float32") - - # Get raw bytes - bytes_array = float_array.tobytes() - - # Encode bytes into base64 - encoded_bytes = base64.b64encode(bytes_array) - - # Turn raw base64 encoded bytes into ASCII - ascii_string = encoded_bytes.decode('ascii') - return ascii_string - - class Handler(BaseHTTPRequestHandler): def send_access_control_headers(self): self.send_header("Access-Control-Allow-Origin", "*") @@ -118,11 +50,48 @@ class Handler(BaseHTTPRequestHandler): "Authorization" ) - def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''): - self.send_response(code) + def do_OPTIONS(self): + self.send_response(200) self.send_access_control_headers() self.send_header('Content-Type', 'application/json') self.end_headers() + self.wfile.write("OK".encode('utf-8')) + + def start_sse(self): + self.send_response(200) + self.send_access_control_headers() + self.send_header('Content-Type', 'text/event-stream') + self.send_header('Cache-Control', 'no-cache') + # self.send_header('Connection', 'keep-alive') + self.end_headers() + + def send_sse(self, chunk: dict): + response = 'data: ' + json.dumps(chunk) + '\r\n\r\n' + debug_msg(response[:-4]) + self.wfile.write(response.encode('utf-8')) + + def end_sse(self): + response = 'data: [DONE]\r\n\r\n' + debug_msg(response[:-4]) + self.wfile.write(response.encode('utf-8')) + + def return_json(self, ret: dict, code: int = 200, no_debug=False): + self.send_response(code) + self.send_access_control_headers() + self.send_header('Content-Type', 'application/json') + + response = json.dumps(ret) + r_utf8 = response.encode('utf-8') + + self.send_header('Content-Length', str(len(r_utf8))) + self.end_headers() + + self.wfile.write(r_utf8) + if not no_debug: + debug_msg(r_utf8) + + def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''): + error_resp = { 'error': { 'message': message, @@ -132,756 +101,237 @@ class Handler(BaseHTTPRequestHandler): } } if internal_message: - error_resp['internal_message'] = internal_message + print(error_type, message) + print(internal_message) + # error_resp['internal_message'] = internal_message - response = json.dumps(error_resp) - self.wfile.write(response.encode('utf-8')) + self.return_json(error_resp, code) - def do_OPTIONS(self): - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() - self.wfile.write("OK".encode('utf-8')) + def openai_error_handler(func): + def wrapper(self): + try: + func(self) + except InvalidRequestError as e: + self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message) + except OpenAIError as e: + self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message) + except Exception as e: + self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc()) + return wrapper + + @openai_error_handler def do_GET(self): - if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'): - current_model_list = [ shared.model_name ] # The real chat/completions model, maybe "None" - embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model - pseudo_model_list = [ # these are expected by so much, so include some here as a dummy - 'gpt-3.5-turbo', # /v1/chat/completions - 'text-curie-001', # /v1/completions, 2k context - 'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768 - ] + debug_msg(self.requestline) + debug_msg(self.headers) + if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'): is_legacy = 'engines' in self.path is_list = self.path in ['/v1/engines', '/v1/models'] - - resp = '' - - if is_legacy and not is_list: # load model + if is_legacy and not is_list: model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):] - - resp = { - "id": model_name, - "object": "engine", - "owner": "self", - "ready": True, - } - if model_name not in pseudo_model_list + embeddings_model_list + current_model_list: # Real model only - # No args. Maybe it works anyways! - # TODO: hack some heuristics into args for better results - - shared.model_name = model_name - unload_model() - - model_settings = get_model_settings_from_yamls(shared.model_name) - shared.settings.update(model_settings) - update_model_parameters(model_settings, initial=True) - - if shared.settings['mode'] != 'instruct': - shared.settings['instruction_template'] = None - - shared.model, shared.tokenizer = load_model(shared.model_name) - - if not shared.model: # load failed. - shared.model_name = "None" - resp['id'] = "None" - resp['ready'] = False - + resp = OAImodels.load_model(model_name) elif is_list: - # TODO: Lora's? - available_model_list = get_available_models() - all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list - - models = {} - - if is_legacy: - models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ] - if not shared.model: - models[0]['ready'] = False - else: - models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ] - - resp = { - "object": "list", - "data": models, - } - + resp = OAImodels.list_models(is_legacy) else: - the_model_name = self.path[len('/v1/models/'):] - resp = { - "id": the_model_name, - "object": "model", - "owned_by": "user", - "permission": [] - } + model_name = self.path[len('/v1/models/'):] + resp = OAImodels.model_info(model_name) - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() - response = json.dumps(resp) - self.wfile.write(response.encode('utf-8')) + self.return_json(resp) elif '/billing/usage' in self.path: - # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31 - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() - - response = json.dumps({ - "total_usage": 0, - }) - self.wfile.write(response.encode('utf-8')) + # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31 + self.return_json({"total_usage": 0}, no_debug=True) else: self.send_error(404) + @openai_error_handler def do_POST(self): - if debug: - print(self.headers) # did you know... python-openai sends your linux kernel & python version? - content_length = int(self.headers['Content-Length']) - body = json.loads(self.rfile.read(content_length).decode('utf-8')) - if debug: - print(body) + if '/v1/audio/transcriptions' in self.path: + r = sr.Recognizer() + + # Parse the form data + form = cgi.FieldStorage( + fp=self.rfile, + headers=self.headers, + environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']} + ) + + audio_file = form['file'].file + audio_data = AudioSegment.from_file(audio_file) + + # Convert AudioSegment to raw data + raw_data = audio_data.raw_data + + # Create AudioData object + audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width) + whipser_language = form.getvalue('language', None) + whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny + + transcription = {"text": ""} + + try: + transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model) + except sr.UnknownValueError: + print("Whisper could not understand audio") + transcription["text"] = "Whisper could not understand audio UnknownValueError" + except sr.RequestError as e: + print("Could not request results from Whisper", e) + transcription["text"] = "Whisper could not understand audio RequestError" + + self.return_json(transcription, no_debug=True) + return + + debug_msg(self.requestline) + debug_msg(self.headers) + + content_length = self.headers.get('Content-Length') + transfer_encoding = self.headers.get('Transfer-Encoding') + + if content_length: + body = json.loads(self.rfile.read(int(content_length)).decode('utf-8')) + elif transfer_encoding == 'chunked': + chunks = [] + while True: + chunk_size = int(self.rfile.readline(), 16) # Read the chunk size + if chunk_size == 0: + break # End of chunks + chunks.append(self.rfile.read(chunk_size)) + self.rfile.readline() # Consume the trailing newline after each chunk + body = json.loads(b''.join(chunks).decode('utf-8')) + else: + self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.") + self.end_headers() + return + + debug_msg(body) if '/completions' in self.path or '/generate' in self.path: if not shared.model: - self.openai_error("No model loaded.") - return + raise ServiceUnavailableError("No model loaded.") is_legacy = '/generate' in self.path - is_chat_request = 'chat' in self.path - resp_list = 'data' if is_legacy else 'choices' - - # XXX model is ignored for now - # model = body.get('model', shared.model_name) # ignored, use existing for now - model = shared.model_name - created_time = int(time.time()) - - cmpl_id = "chatcmpl-%d" % (created_time) if is_chat_request else "conv-%d" % (created_time) - - # Request Parameters - # Try to use openai defaults or map them to something with the same intent - req_params = default_req_params.copy() - stopping_strings = [] - - if 'stop' in body: - if isinstance(body['stop'], str): - stopping_strings.extend([body['stop']]) - elif isinstance(body['stop'], list): - stopping_strings.extend(body['stop']) - - truncation_length = default(shared.settings, 'truncation_length', 2048) - truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length) - - default_max_tokens = truncation_length if is_chat_request else 16 # completions default, chat default is 'inf' so we need to cap it. - - max_tokens_str = 'length' if is_legacy else 'max_tokens' - max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens)) - # if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough - - req_params['max_new_tokens'] = max_tokens - req_params['truncation_length'] = truncation_length - req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0 - req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0) - req_params['top_k'] = default(body, 'best_of', default_req_params['top_k']) - req_params['suffix'] = default(body, 'suffix', default_req_params['suffix']) - req_params['stream'] = default(body, 'stream', default_req_params['stream']) - req_params['echo'] = default(body, 'echo', default_req_params['echo']) - req_params['seed'] = shared.settings.get('seed', default_req_params['seed']) - req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token']) - - is_streaming = req_params['stream'] - - self.send_response(200) - self.send_access_control_headers() - if is_streaming: - self.send_header('Content-Type', 'text/event-stream') - self.send_header('Cache-Control', 'no-cache') - # self.send_header('Connection', 'keep-alive') - else: - self.send_header('Content-Type', 'application/json') - self.end_headers() - - token_count = 0 - completion_token_count = 0 - prompt = '' - stream_object_type = '' - object_type = '' - - if is_chat_request: - # Chat Completions - stream_object_type = 'chat.completions.chunk' - object_type = 'chat.completions' - - messages = body['messages'] - - role_formats = { - 'user': 'user: {message}\n', - 'assistant': 'assistant: {message}\n', - 'system': '{message}', - 'context': 'You are a helpful assistant. Answer as concisely as possible.', - 'prompt': 'assistant:', - } - - # Instruct models can be much better - if shared.settings['instruction_template']: - try: - instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r')) - - template = instruct['turn_template'] - system_message_template = "{message}" - system_message_default = instruct['context'] - bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token - user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user']) - bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot']) - bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ') - - role_formats = { - 'user': user_message_template, - 'assistant': bot_message_template, - 'system': system_message_template, - 'context': system_message_default, - 'prompt': bot_prompt, - } - - if 'Alpaca' in shared.settings['instruction_template']: - stopping_strings.extend(['\n###']) - elif instruct['user']: # WizardLM and some others have no user prompt. - stopping_strings.extend(['\n' + instruct['user'], instruct['user']]) - - if debug: - print(f"Loaded instruction role format: {shared.settings['instruction_template']}") - - except Exception as e: - stopping_strings.extend(['\nuser:']) - - print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}") - print("Warning: Loaded default instruction-following template for model.") - - else: - stopping_strings.extend(['\nuser:']) - print("Warning: Loaded default instruction-following template for model.") - - system_msgs = [] - chat_msgs = [] - - # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date} - context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else '' - if context_msg: - system_msgs.extend([context_msg]) - - # Maybe they sent both? This is not documented in the API, but some clients seem to do this. - if 'prompt' in body: - prompt_msg = role_formats['system'].format(message=body['prompt']) - system_msgs.extend([prompt_msg]) - - for m in messages: - role = m['role'] - content = m['content'] - msg = role_formats[role].format(message=content) - if role == 'system': - system_msgs.extend([msg]) - else: - chat_msgs.extend([msg]) - - # can't really truncate the system messages - system_msg = '\n'.join(system_msgs) - if system_msg and system_msg[-1] != '\n': - system_msg = system_msg + '\n' - - system_token_count = len(encode(system_msg)[0]) - remaining_tokens = truncation_length - system_token_count - chat_msg = '' - - while chat_msgs: - new_msg = chat_msgs.pop() - new_size = len(encode(new_msg)[0]) - if new_size <= remaining_tokens: - chat_msg = new_msg + chat_msg - remaining_tokens -= new_size - else: - print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).") - break - - prompt = system_msg + chat_msg + role_formats['prompt'] - - token_count = len(encode(prompt)[0]) - - else: - # Text Completions - stream_object_type = 'text_completion.chunk' - object_type = 'text_completion' - - # ... encoded as a string, array of strings, array of tokens, or array of token arrays. - if is_legacy: - prompt = body['context'] # Older engines.generate API - else: - prompt = body['prompt'] # XXX this can be different types - - if isinstance(prompt, list): - self.openai_error("API Batched generation not yet supported.") - return - - token_count = len(encode(prompt)[0]) - if token_count >= truncation_length: - new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count) - prompt = prompt[-new_len:] - new_token_count = len(encode(prompt)[0]) - print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.") - token_count = new_token_count - - if truncation_length - token_count < req_params['max_new_tokens']: - print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {truncation_length - token_count}") - req_params['max_new_tokens'] = truncation_length - token_count - print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}") + is_streaming = body.get('stream', False) if is_streaming: - # begin streaming - chunk = { - "id": cmpl_id, - "object": stream_object_type, - "created": created_time, - "model": shared.model_name, - resp_list: [{ - "index": 0, - "finish_reason": None, - }], - } + self.start_sse() - if stream_object_type == 'text_completion.chunk': - chunk[resp_list][0]["text"] = "" + response = [] + if 'chat' in self.path: + response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy) else: - # So yeah... do both methods? delta and messages. - chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''} - chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''} + response = OAIcompletions.stream_completions(body, is_legacy=is_legacy) - response = 'data: ' + json.dumps(chunk) + '\r\n\r\n' - self.wfile.write(response.encode('utf-8')) + for resp in response: + self.send_sse(resp) - # generate reply ####################################### - if debug: - print({'prompt': prompt, 'req_params': req_params}) - generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) + self.end_sse() - answer = '' - seen_content = '' - longest_stop_len = max([len(x) for x in stopping_strings] + [0]) - - for a in generator: - answer = a - - stop_string_found = False - len_seen = len(seen_content) - search_start = max(len_seen - longest_stop_len, 0) - - for string in stopping_strings: - idx = answer.find(string, search_start) - if idx != -1: - answer = answer[:idx] # clip it. - stop_string_found = True - - if stop_string_found: - break - - # If something like "\nYo" is generated just before "\nYou:" - # is completed, buffer and generate more, don't send it - buffer_and_continue = False - - for string in stopping_strings: - for j in range(len(string) - 1, 0, -1): - if answer[-j:] == string[:j]: - buffer_and_continue = True - break - else: - continue - break - - if buffer_and_continue: - continue - - if is_streaming: - # Streaming - new_content = answer[len_seen:] - - if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. - continue - - seen_content = answer - chunk = { - "id": cmpl_id, - "object": stream_object_type, - "created": created_time, - "model": shared.model_name, - resp_list: [{ - "index": 0, - "finish_reason": None, - }], - } - - # strip extra leading space off new generated content - if len_seen == 0 and new_content[0] == ' ': - new_content = new_content[1:] - - if stream_object_type == 'text_completion.chunk': - chunk[resp_list][0]['text'] = new_content - else: - # So yeah... do both methods? delta and messages. - chunk[resp_list][0]['message'] = {'content': new_content} - chunk[resp_list][0]['delta'] = {'content': new_content} - response = 'data: ' + json.dumps(chunk) + '\r\n\r\n' - self.wfile.write(response.encode('utf-8')) - completion_token_count += len(encode(new_content)[0]) - - if is_streaming: - chunk = { - "id": cmpl_id, - "object": stream_object_type, - "created": created_time, - "model": model, # TODO: add Lora info? - resp_list: [{ - "index": 0, - "finish_reason": "stop", - }], - "usage": { - "prompt_tokens": token_count, - "completion_tokens": completion_token_count, - "total_tokens": token_count + completion_token_count - } - } - if stream_object_type == 'text_completion.chunk': - chunk[resp_list][0]['text'] = '' - else: - # So yeah... do both methods? delta and messages. - chunk[resp_list][0]['message'] = {'content': ''} - chunk[resp_list][0]['delta'] = {'content': ''} - - response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n' - self.wfile.write(response.encode('utf-8')) - # Finished if streaming. - if debug: - if answer and answer[0] == ' ': - answer = answer[1:] - print({'answer': answer}, chunk) - return - - # strip extra leading space off new generated content - if answer and answer[0] == ' ': - answer = answer[1:] - - if debug: - print({'response': answer}) - - completion_token_count = len(encode(answer)[0]) - stop_reason = "stop" - if token_count + completion_token_count >= truncation_length: - stop_reason = "length" - - resp = { - "id": cmpl_id, - "object": object_type, - "created": created_time, - "model": model, # TODO: add Lora info? - resp_list: [{ - "index": 0, - "finish_reason": stop_reason, - }], - "usage": { - "prompt_tokens": token_count, - "completion_tokens": completion_token_count, - "total_tokens": token_count + completion_token_count - } - } - - if is_chat_request: - resp[resp_list][0]["message"] = {"role": "assistant", "content": answer} else: - resp[resp_list][0]["text"] = answer + response = '' + if 'chat' in self.path: + response = OAIcompletions.chat_completions(body, is_legacy=is_legacy) + else: + response = OAIcompletions.completions(body, is_legacy=is_legacy) - response = json.dumps(resp) - self.wfile.write(response.encode('utf-8')) + self.return_json(response) elif '/edits' in self.path: + # deprecated + if not shared.model: - self.openai_error("No model loaded.") - return + raise ServiceUnavailableError("No model loaded.") - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() + req_params = get_default_req_params() - created_time = int(time.time()) - - # Using Alpaca format, this may work with other models too. instruction = body['instruction'] input = body.get('input', '') + temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0 + top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0) - # Request parameters - req_params = default_req_params.copy() - stopping_strings = [] + response = OAIedits.edits(instruction, input, temperature, top_p) - # Alpaca is verbose so a good default prompt - default_template = ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - ) + self.return_json(response) - instruction_template = default_template - - # Use the special instruction/input/response template for anything trained like Alpaca - if shared.settings['instruction_template']: - if 'Alpaca' in shared.settings['instruction_template']: - stopping_strings.extend(['\n###']) - else: - try: - instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r')) + elif '/images/generations' in self.path: + if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')): + raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.") - template = instruct['turn_template'] - template = template\ - .replace('<|user|>', instruct.get('user', ''))\ - .replace('<|bot|>', instruct.get('bot', ''))\ - .replace('<|user-message|>', '{instruction}\n{input}') - - instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ') - if instruct['user']: - stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ]) - - except Exception as e: - instruction_template = default_template - print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}") - print("Warning: Loaded default instruction-following template (Alpaca) for model.") - else: - stopping_strings.extend(['\n###']) - print("Warning: Loaded default instruction-following template (Alpaca) for model.") - - - edit_task = instruction_template.format(instruction=instruction, input=input) - - truncation_length = default(shared.settings, 'truncation_length', 2048) - token_count = len(encode(edit_task)[0]) - max_tokens = truncation_length - token_count - - req_params['max_new_tokens'] = max_tokens - req_params['truncation_length'] = truncation_length - req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0 - req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0) - req_params['seed'] = shared.settings.get('seed', default_req_params['seed']) - req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token']) - - if debug: - print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count}) - - generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False) - - longest_stop_len = max([len(x) for x in stopping_strings] + [0]) - answer = '' - seen_content = '' - for a in generator: - answer = a - - stop_string_found = False - len_seen = len(seen_content) - search_start = max(len_seen - longest_stop_len, 0) - - for string in stopping_strings: - idx = answer.find(string, search_start) - if idx != -1: - answer = answer[:idx] # clip it. - stop_string_found = True - - if stop_string_found: - break - - - # some reply's have an extra leading space to fit the instruction template, just clip it off from the reply. - if edit_task[-1] != '\n' and answer and answer[0] == ' ': - answer = answer[1:] - - completion_token_count = len(encode(answer)[0]) - - resp = { - "object": "edit", - "created": created_time, - "choices": [{ - "text": answer, - "index": 0, - }], - "usage": { - "prompt_tokens": token_count, - "completion_tokens": completion_token_count, - "total_tokens": token_count + completion_token_count - } - } - - if debug: - print({'answer': answer, 'completion_token_count': completion_token_count}) - - response = json.dumps(resp) - self.wfile.write(response.encode('utf-8')) - - elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ: - # Stable Diffusion callout wrapper for txt2img - # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E - # the results will be limited and likely poor. SD has hundreds of models and dozens of settings. - # If you want high quality tailored results you should just use the Stable Diffusion API directly. - # it's too general an API to try and shape the result with specific tags like "masterpiece", etc, - # Will probably work best with the stock SD models. - # SD configuration is beyond the scope of this API. - # At this point I will not add the edits and variations endpoints (ie. img2img) because they - # require changing the form data handling to accept multipart form data, also to properly support - # url return types will require file management and a web serving files... Perhaps later! - - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() - - width, height = [ int(x) for x in default(body, 'size', '1024x1024').split('x') ] # ignore the restrictions on size + prompt = body['prompt'] + size = default(body, 'size', '1024x1024') response_format = default(body, 'response_format', 'url') # or b64_json - - payload = { - 'prompt': body['prompt'], # ignore prompt limit of 1000 characters - 'width': width, - 'height': height, - 'batch_size': default(body, 'n', 1) # ignore the batch limits of max 10 - } + n = default(body, 'n', 1) # ignore the batch limits of max 10 - resp = { - 'created': int(time.time()), - 'data': [] - } + response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n) - # TODO: support SD_WEBUI_AUTH username:password pair. - sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img" + self.return_json(response, no_debug=True) - response = requests.post(url=sd_url, json=payload) - r = response.json() - # r['parameters']... - for b64_json in r['images']: - if response_format == 'b64_json': - resp['data'].extend([{'b64_json': b64_json}]) - else: - resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this + elif '/embeddings' in self.path: + encoding_format = body.get('encoding_format', '') - response = json.dumps(resp) - self.wfile.write(response.encode('utf-8')) + input = body.get('input', body.get('text', '')) + if not input: + raise InvalidRequestError("Missing required argument input", params='input') - elif '/embeddings' in self.path and embedding_model is not None: - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() - - input = body['input'] if 'input' in body else body['text'] if type(input) is str: input = [input] - embeddings = embedding_model.encode(input).tolist() + response = OAIembeddings.embeddings(input, encoding_format) - def enc_emb(emb): - # If base64 is specified, encode. Otherwise, do nothing. - if body.get("encoding_format", "") == "base64": - return float_list_to_base64(emb) - else: - return emb - data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)] - - response = json.dumps({ - "object": "list", - "data": data, - "model": st_model, # return the real model - "usage": { - "prompt_tokens": 0, - "total_tokens": 0, - } - }) - - if debug: - print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}") - self.wfile.write(response.encode('utf-8')) + self.return_json(response, no_debug=True) elif '/moderations' in self.path: - # for now do nothing, just don't error. - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() + input = body['input'] + if not input: + raise InvalidRequestError("Missing required argument input", params='input') - response = json.dumps({ - "id": "modr-5MWoLO", - "model": "text-moderation-001", - "results": [{ - "categories": { - "hate": False, - "hate/threatening": False, - "self-harm": False, - "sexual": False, - "sexual/minors": False, - "violence": False, - "violence/graphic": False - }, - "category_scores": { - "hate": 0.0, - "hate/threatening": 0.0, - "self-harm": 0.0, - "sexual": 0.0, - "sexual/minors": 0.0, - "violence": 0.0, - "violence/graphic": 0.0 - }, - "flagged": False - }] - }) - self.wfile.write(response.encode('utf-8')) + response = OAImoderations.moderations(input) + + self.return_json(response, no_debug=True) elif self.path == '/api/v1/token-count': # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side. - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() + response = token_count(body['prompt']) - tokens = encode(body['prompt'])[0] - response = json.dumps({ - 'results': [{ - 'tokens': len(tokens) - }] - }) - self.wfile.write(response.encode('utf-8')) + self.return_json(response, no_debug=True) + + elif self.path == '/api/v1/token/encode': + # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models + encoding_format = body.get('encoding_format', '') + + response = token_encode(body['input'], encoding_format) + + self.return_json(response, no_debug=True) + + elif self.path == '/api/v1/token/decode': + # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models + encoding_format = body.get('encoding_format', '') + + response = token_decode(body['input'], encoding_format) + + self.return_json(response, no_debug=True) else: - print(self.path, self.headers) self.send_error(404) def run_server(): - global embedding_model - try: - embedding_model = SentenceTransformer(st_model) - print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}") - except: - print(f"\nFailed to load embedding model: {st_model}") - pass - - server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port']) + port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001))) + server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port) server = ThreadingHTTPServer(server_addr, Handler) if shared.args.share: try: from flask_cloudflared import _run_cloudflared - public_url = _run_cloudflared(params['port'], params['port'] + 1) - print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1') + public_url = _run_cloudflared(port, port + 1) + print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1') except ImportError: print('You should install flask_cloudflared manually') else: - print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1') - + print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1') + server.serve_forever() diff --git a/extensions/openai/tokens.py b/extensions/openai/tokens.py new file mode 100644 index 0000000..0338e7f --- /dev/null +++ b/extensions/openai/tokens.py @@ -0,0 +1,36 @@ +from modules.text_generation import decode, encode + + +def token_count(prompt): + tokens = encode(prompt)[0] + + return { + 'results': [{ + 'tokens': len(tokens) + }] + } + + +def token_encode(input, encoding_format): + # if isinstance(input, list): + tokens = encode(input)[0] + + return { + 'results': [{ + 'tokens': tokens, + 'length': len(tokens), + }] + } + + +def token_decode(tokens, encoding_format): + # if isinstance(input, list): + # if encoding_format == "base64": + # tokens = base64_to_float_list(tokens) + output = decode(tokens)[0] + + return { + 'results': [{ + 'text': output + }] + } diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py new file mode 100644 index 0000000..49fc951 --- /dev/null +++ b/extensions/openai/utils.py @@ -0,0 +1,30 @@ +import base64 +import os + +import numpy as np + +def float_list_to_base64(float_array: np.ndarray) -> str: + # Convert the list to a float32 array that the OpenAPI client expects + # float_array = np.array(float_list, dtype="float32") + + # Get raw bytes + bytes_array = float_array.tobytes() + + # Encode bytes into base64 + encoded_bytes = base64.b64encode(bytes_array) + + # Turn raw base64 encoded bytes into ASCII + ascii_string = encoded_bytes.decode('ascii') + return ascii_string + + +def end_line(s): + if s and s[-1] != '\n': + s = s + '\n' + return s + + +def debug_msg(*args, **kwargs): + from extensions.openai.script import params + if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)): + print(*args, **kwargs) diff --git a/extensions/perplexity_colors/script.py b/extensions/perplexity_colors/script.py new file mode 100644 index 0000000..2a986ac --- /dev/null +++ b/extensions/perplexity_colors/script.py @@ -0,0 +1,309 @@ +import time + +import gradio +import numpy as np +import torch +from transformers import LogitsProcessor + +from modules import html_generator, shared + +params = { + 'active': True, + 'color_by_perplexity': False, + 'color_by_probability': False, + 'ppl_scale': 15.0, # No slider for this right now, because I don't think it really needs to be changed. Very large perplexity scores don't show up often. + 'probability_dropdown': False, + 'verbose': False # For debugging mostly +} + + +class PerplexityLogits(LogitsProcessor): + def __init__(self, verbose=False): + self.generated_token_ids = [] + self.selected_probs = [] + self.top_token_ids_list = [] + self.top_probs_list = [] + self.perplexities_list = [] + self.last_probs = None + self.verbose = verbose + + def __call__(self, input_ids, scores): + # t0 = time.time() + probs = torch.softmax(scores, dim=-1, dtype=torch.float) + log_probs = torch.nan_to_num(torch.log(probs)) # Note: This is to convert log(0) nan to 0, but probs*log_probs makes this 0 not affect the perplexity. + entropy = -torch.sum(probs * log_probs) + entropy = entropy.cpu().numpy() + perplexity = round(float(np.exp(entropy)), 4) + self.perplexities_list.append(perplexity) + last_token_id = int(input_ids[0][-1].cpu().numpy().item()) + # Store the generated tokens (not sure why this isn't accessible in the output endpoint!) + self.generated_token_ids.append(last_token_id) + # Get last probability, and add to the list if it wasn't there + if len(self.selected_probs) > 0: + # Is the selected token in the top tokens? + if self.verbose: + print('Probs: Token after', shared.tokenizer.decode(last_token_id)) + print('Probs:', [shared.tokenizer.decode(token_id) for token_id in self.top_token_ids_list[-1][0]]) + print('Probs:', [round(float(prob), 4) for prob in self.top_probs_list[-1][0]]) + if last_token_id in self.top_token_ids_list[-1][0]: + idx = self.top_token_ids_list[-1][0].index(last_token_id) + self.selected_probs.append(self.top_probs_list[-1][0][idx]) + else: + self.top_token_ids_list[-1][0].append(last_token_id) + last_prob = round(float(self.last_probs[last_token_id]), 4) + self.top_probs_list[-1][0].append(last_prob) + self.selected_probs.append(last_prob) + else: + self.selected_probs.append(1.0) # Placeholder for the last token of the prompt + + if self.verbose: + pplbar = "-" + if not np.isnan(perplexity): + pplbar = "*" * round(perplexity) + print(f"PPL: Token after {shared.tokenizer.decode(last_token_id)}\t{perplexity:.2f}\t{pplbar}") + + # Get top 5 probabilities + top_tokens_and_probs = torch.topk(probs, 5) + top_probs = top_tokens_and_probs.values.cpu().numpy().astype(float).tolist() + top_token_ids = top_tokens_and_probs.indices.cpu().numpy().astype(int).tolist() + + self.top_token_ids_list.append(top_token_ids) + self.top_probs_list.append(top_probs) + + probs = probs.cpu().numpy().flatten() + self.last_probs = probs # Need to keep this as a reference for top probs + + # t1 = time.time() + # print(f"PPL Processor: {(t1-t0):.3f} s") + # About 1 ms, though occasionally up to around 100 ms, not sure why... + # Doesn't actually modify the logits! + return scores + + +# Stores the perplexity and top probabilities +ppl_logits_processor = None + + +def logits_processor_modifier(logits_processor_list, input_ids): + global ppl_logits_processor + if params['active']: + ppl_logits_processor = PerplexityLogits(verbose=params['verbose']) + logits_processor_list.append(ppl_logits_processor) + + +def output_modifier(text): + global ppl_logits_processor + # t0 = time.time() + + if not params['active']: + return text + + # TODO: It's probably more efficient to do this above rather than modifying all these lists + # Remove last element of perplexities_list, top_token_ids_list, top_tokens_list, top_probs_list since everything is off by one because this extension runs before generation + perplexities = ppl_logits_processor.perplexities_list[:-1] + top_token_ids_list = ppl_logits_processor.top_token_ids_list[:-1] + top_tokens_list = [[shared.tokenizer.decode(token_id) for token_id in top_token_ids[0]] for top_token_ids in top_token_ids_list] + top_probs_list = ppl_logits_processor.top_probs_list[:-1] + # Remove first element of generated_token_ids, generated_tokens, selected_probs because they are for the last token of the prompt + gen_token_ids = ppl_logits_processor.generated_token_ids[1:] + gen_tokens = [shared.tokenizer.decode(token_id) for token_id in gen_token_ids] + sel_probs = ppl_logits_processor.selected_probs[1:] + + end_part = '' if params['probability_dropdown'] else '' # Helps with finding the index after replacing part of the text. + + i = 0 + for token, prob, ppl, top_tokens, top_probs in zip(gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list): + color = 'ffffff' + if params['color_by_probability'] and params['color_by_perplexity']: + color = probability_perplexity_color_scale(prob, ppl) + elif params['color_by_perplexity']: + color = perplexity_color_scale(ppl) + elif params['color_by_probability']: + color = probability_color_scale(prob) + if token in text[i:]: + if params['probability_dropdown']: + text = text[:i] + text[i:].replace(token, add_dropdown_html(token, color, top_tokens, top_probs[0], ppl), 1) + else: + text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1) + i += text[i:].find(end_part) + len(end_part) + + # Use full perplexity list for calculating the average here. + print('Average perplexity:', round(np.mean(ppl_logits_processor.perplexities_list[:-1]), 4)) + # t1 = time.time() + # print(f"Modifier: {(t1-t0):.3f} s") + # About 50 ms + return text + + +def probability_color_scale(prob): + ''' + Green-yellow-red color scale + ''' + + rv = 0 + gv = 0 + if prob <= 0.5: + rv = 'ff' + gv = hex(int(255 * prob * 2))[2:] + if len(gv) < 2: + gv = '0' * (2 - len(gv)) + gv + else: + rv = hex(int(255 - 255 * (prob - 0.5) * 2))[2:] + gv = 'ff' + if len(rv) < 2: + rv = '0' * (2 - len(rv)) + rv + + return rv + gv + '00' + + +def perplexity_color_scale(ppl): + ''' + Red component only, white for 0 perplexity (sorry if you're not in dark mode) + ''' + value = hex(max(int(255.0 - params['ppl_scale'] * (float(ppl) - 1.0)), 0))[2:] + if len(value) < 2: + value = '0' * (2 - len(value)) + value + + return 'ff' + value + value + + +def probability_perplexity_color_scale(prob, ppl): + ''' + Green-yellow-red for probability and blue component for perplexity + ''' + + rv = 0 + gv = 0 + bv = hex(min(max(int(params['ppl_scale'] * (float(ppl) - 1.0)), 0), 255))[2:] + if len(bv) < 2: + bv = '0' * (2 - len(bv)) + bv + + if prob <= 0.5: + rv = 'ff' + gv = hex(int(255 * prob * 2))[2:] + if len(gv) < 2: + gv = '0' * (2 - len(gv)) + gv + else: + rv = hex(int(255 - 255 * (prob - 0.5) * 2))[2:] + gv = 'ff' + if len(rv) < 2: + rv = '0' * (2 - len(rv)) + rv + + return rv + gv + bv + + +def add_color_html(token, color): + return f'{token}' + + +# TODO: Major issue: Applying this to too many tokens will cause a permanent slowdown in generation speed until the messages are removed from the history. +# I think the issue is from HTML elements taking up space in the visible history, and things like history deepcopy add latency proportional to the size of the history. +# Potential solution is maybe to modify the main generation code to send just the internal text and not the visible history, to avoid moving too much around. +# I wonder if we can also avoid using deepcopy here. +def add_dropdown_html(token, color, top_tokens, top_probs, perplexity=0): + html = f'
{token}
' + return html # About 750 characters per token... + + +def custom_css(): + return """ + .dropdown { + display: none; + position: absolute; + z-index: 50; + background-color: var(--block-background-fill); + box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); + width: max-content; + overflow: visible; + padding: 5px; + border-radius: 10px; + border: 1px solid var(--border-color-primary); + } + + .dropdown-content { + border: none; + z-index: 50; + } + + .dropdown-content tr.selected { + background-color: var(--block-label-background-fill); + } + + .dropdown-content td { + color: var(--body-text-color); + } + + .hoverable { + color: var(--body-text-color); + position: relative; + display: inline-block; + overflow: visible; + font-size: 15px; + line-height: 1.75; + margin: 0; + padding: 0; + } + + .hoverable:hover .dropdown { + display: block; + } + + pre { + white-space: pre-wrap; + } + + # TODO: This makes the hover menus extend outside the bounds of the chat area, which is good. + # However, it also makes the scrollbar disappear, which is bad. + # The scroll bar needs to still be present. So for now, we can't see dropdowns that extend past the edge of the chat area. + #.chat { + # overflow-y: auto; + #} + """ + + +# Monkeypatch applied to html_generator.py +# We simply don't render markdown into HTML. We wrap everything in
 tags to preserve whitespace
+# formatting. If you're coloring tokens by perplexity or probability, or especially if you're using
+# the probability dropdown, you probably care more about seeing the tokens the model actually outputted
+# rather than rendering ```code blocks``` or *italics*.
+def convert_to_markdown(string):
+    return '
' + string + '
' + + +html_generator.convert_to_markdown = convert_to_markdown + + +def ui(): + def update_active_check(x): + params.update({'active': x}) + + def update_color_by_ppl_check(x): + params.update({'color_by_perplexity': x}) + + def update_color_by_prob_check(x): + params.update({'color_by_probability': x}) + + def update_prob_dropdown_check(x): + params.update({'probability_dropdown': x}) + + active_check = gradio.Checkbox(value=True, label="Compute probabilities and perplexity scores", info="Activate this extension. Note that this extension currently does not work with exllama or llama.cpp.") + color_by_ppl_check = gradio.Checkbox(value=False, label="Color by perplexity", info="Higher perplexity is more red. If also showing probability, higher perplexity has more blue component.") + color_by_prob_check = gradio.Checkbox(value=False, label="Color by probability", info="Green-yellow-red linear scale, with 100% green, 50% yellow, 0% red.") + prob_dropdown_check = gradio.Checkbox(value=False, label="Probability dropdown", info="Hover over a token to show a dropdown of top token probabilities. Currently slightly buggy with whitespace between tokens.") + + active_check.change(update_active_check, active_check, None) + color_by_ppl_check.change(update_color_by_ppl_check, color_by_ppl_check, None) + color_by_prob_check.change(update_color_by_prob_check, color_by_prob_check, None) + prob_dropdown_check.change(update_prob_dropdown_check, prob_dropdown_check, None) diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py index 78488cd..e33367d 100644 --- a/extensions/sd_api_pictures/script.py +++ b/extensions/sd_api_pictures/script.py @@ -133,6 +133,9 @@ def get_SD_pictures(description, character): if params['manage_VRAM']: give_VRAM_priority('SD') + description = re.sub('', ' ', description) + description = f"({description}:1)" + payload = { "prompt": params['prompt_prefix'] + description, "seed": params['seed'], @@ -332,8 +335,8 @@ def ui(): negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt') with gr.Row(): with gr.Column(): - width = gr.Slider(256, 768, value=params['width'], step=64, label='Width') - height = gr.Slider(256, 768, value=params['height'], step=64, label='Height') + width = gr.Slider(64, 2048, value=params['width'], step=64, label='Width') + height = gr.Slider(64, 2048, value=params['height'], step=64, label='Height') with gr.Column(variant="compact", elem_id="sampler_col"): with gr.Row(elem_id="sampler_row"): sampler_name = gr.Dropdown(value=params['sampler_name'], label='Sampling method', elem_id="sampler_box") diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index dbbeb0f..f8e6c96 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -5,11 +5,10 @@ import gradio as gr import torch from transformers import BlipForConditionalGeneration, BlipProcessor -from modules import chat, shared +from modules import chat, shared, ui_chat from modules.ui import gather_interface_values +from modules.utils import gradio -# If 'state' is True, will hijack the next chat generation with -# custom input text given by 'value' in the format [text, visible_text] input_hijack = { 'state': False, 'value': ["", ""] @@ -19,6 +18,15 @@ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") +def chat_input_modifier(text, visible_text, state): + global input_hijack + if input_hijack['state']: + input_hijack['state'] = False + return input_hijack['value'] + else: + return text, visible_text + + def caption_image(raw_image): inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32) out = model.generate(**inputs, max_new_tokens=100) @@ -41,7 +49,10 @@ def ui(): # Prepare the input hijack, update the interface values, call the generation function, and clear the picture picture_select.upload( - lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - chat.generate_chat_reply_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then( + lambda picture, name1, name2: input_hijack.update({ + "state": True, + "value": generate_chat_picture(picture, name1, name2) + }), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then( + gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.generate_chat_reply_wrapper, gradio(ui_chat.inputs), gradio('display', 'history'), show_progress=False).then( lambda: None, None, picture_select, show_progress=False) diff --git a/extensions/silero_tts/harvard_sentences.txt b/extensions/silero_tts/harvard_sentences.txt new file mode 100644 index 0000000..958d7f3 --- /dev/null +++ b/extensions/silero_tts/harvard_sentences.txt @@ -0,0 +1,720 @@ +The birch canoe slid on the smooth planks. +Glue the sheet to the dark blue background. +It's easy to tell the depth of a well. +These days a chicken leg is a rare dish. +Rice is often served in round bowls. +The juice of lemons makes fine punch. +The box was thrown beside the parked truck. +The hogs were fed chopped corn and garbage. +Four hours of steady work faced us. +A large size in stockings is hard to sell. +The boy was there when the sun rose. +A rod is used to catch pink salmon. +The source of the huge river is the clear spring. +Kick the ball straight and follow through. +Help the woman get back to her feet. +A pot of tea helps to pass the evening. +Smoky fires lack flame and heat. +The soft cushion broke the man's fall. +The salt breeze came across from the sea. +The girl at the booth sold fifty bonds. +The small pup gnawed a hole in the sock. +The fish twisted and turned on the bent hook. +Press the pants and sew a button on the vest. +The swan dive was far short of perfect. +The beauty of the view stunned the young boy. +Two blue fish swam in the tank. +Her purse was full of useless trash. +The colt reared and threw the tall rider. +It snowed, rained, and hailed the same morning. +Read verse out loud for pleasure. +Hoist the load to your left shoulder. +Take the winding path to reach the lake. +Note closely the size of the gas tank. +Wipe the grease off his dirty face. +Mend the coat before you go out. +The wrist was badly strained and hung limp. +The stray cat gave birth to kittens. +The young girl gave no clear response. +The meal was cooked before the bell rang. +What joy there is in living. +A king ruled the state in the early days. +The ship was torn apart on the sharp reef. +Sickness kept him home the third week. +The wide road shimmered in the hot sun. +The lazy cow lay in the cool grass. +Lift the square stone over the fence. +The rope will bind the seven books at once. +Hop over the fence and plunge in. +The friendly gang left the drug store. +Mesh wire keeps chicks inside. +The frosty air passed through the coat. +The crooked maze failed to fool the mouse. +Adding fast leads to wrong sums. +The show was a flop from the very start. +A saw is a tool used for making boards. +The wagon moved on well oiled wheels. +March the soldiers past the next hill. +A cup of sugar makes sweet fudge. +Place a rosebush near the porch steps. +Both lost their lives in the raging storm. +We talked of the side show in the circus. +Use a pencil to write the first draft. +He ran half way to the hardware store. +The clock struck to mark the third period. +A small creek cut across the field. +Cars and busses stalled in snow drifts. +The set of china hit the floor with a crash. +This is a grand season for hikes on the road. +The dune rose from the edge of the water. +Those words were the cue for the actor to leave. +A yacht slid around the point into the bay. +The two met while playing on the sand. +The ink stain dried on the finished page. +The walled town was seized without a fight. +The lease ran out in sixteen weeks. +A tame squirrel makes a nice pet. +The horn of the car woke the sleeping cop. +The heart beat strongly and with firm strokes. +The pearl was worn in a thin silver ring. +The fruit peel was cut in thick slices. +The Navy attacked the big task force. +See the cat glaring at the scared mouse. +There are more than two factors here. +The hat brim was wide and too droopy. +The lawyer tried to lose his case. +The grass curled around the fence post. +Cut the pie into large parts. +Men strive but seldom get rich. +Always close the barn door tight. +He lay prone and hardly moved a limb. +The slush lay deep along the street. +A wisp of cloud hung in the blue air. +A pound of sugar costs more than eggs. +The fin was sharp and cut the clear water. +The play seems dull and quite stupid. +Bail the boat to stop it from sinking. +The term ended in late June that year. +A tusk is used to make costly gifts. +Ten pins were set in order. +The bill was paid every third week. +Oak is strong and also gives shade. +Cats and dogs each hate the other. +The pipe began to rust while new. +Open the crate but don't break the glass. +Add the sum to the product of these three. +Thieves who rob friends deserve jail. +The ripe taste of cheese improves with age. +Act on these orders with great speed. +The hog crawled under the high fence. +Move the vat over the hot fire. +The bark of the pine tree was shiny and dark. +Leaves turn brown and yellow in the fall. +The pennant waved when the wind blew. +Split the log with a quick, sharp blow. +Burn peat after the logs give out. +He ordered peach pie with ice cream. +Weave the carpet on the right hand side. +Hemp is a weed found in parts of the tropics. +A lame back kept his score low. +We find joy in the simplest things. +Type out three lists of orders. +The harder he tried the less he got done. +The boss ran the show with a watchful eye. +The cup cracked and spilled its contents. +Paste can cleanse the most dirty brass. +The slang word for raw whiskey is booze. +It caught its hind paw in a rusty trap. +The wharf could be seen at the farther shore. +Feel the heat of the weak dying flame. +The tiny girl took off her hat. +A cramp is no small danger on a swim. +He said the same phrase thirty times. +Pluck the bright rose without leaves. +Two plus seven is less than ten. +The glow deepened in the eyes of the sweet girl. +Bring your problems to the wise chief. +Write a fond note to the friend you cherish. +Clothes and lodging are free to new men. +We frown when events take a bad turn. +Port is a strong wine with a smoky taste. +The young kid jumped the rusty gate. +Guess the results from the first scores. +A salt pickle tastes fine with ham. +The just claim got the right verdict. +These thistles bend in a high wind. +Pure bred poodles have curls. +The tree top waved in a graceful way. +The spot on the blotter was made by green ink. +Mud was spattered on the front of his white shirt. +The cigar burned a hole in the desk top. +The empty flask stood on the tin tray. +A speedy man can beat this track mark. +He broke a new shoelace that day. +The coffee stand is too high for the couch. +The urge to write short stories is rare. +The pencils have all been used. +The pirates seized the crew of the lost ship. +We tried to replace the coin but failed. +She sewed the torn coat quite neatly. +The sofa cushion is red and of light weight. +The jacket hung on the back of the wide chair. +At that high level the air is pure. +Drop the two when you add the figures. +A filing case is now hard to buy. +An abrupt start does not win the prize. +Wood is best for making toys and blocks. +The office paint was a dull, sad tan. +He knew the skill of the great young actress. +A rag will soak up spilled water. +A shower of dirt fell from the hot pipes. +Steam hissed from the broken valve. +The child almost hurt the small dog. +There was a sound of dry leaves outside. +The sky that morning was clear and bright blue. +Torn scraps littered the stone floor. +Sunday is the best part of the week. +The doctor cured him with these pills. +The new girl was fired today at noon. +They felt gay when the ship arrived in port. +Add the store's account to the last cent. +Acid burns holes in wool cloth. +Fairy tales should be fun to write. +Eight miles of woodland burned to waste. +The third act was dull and tired the players. +A young child should not suffer fright. +Add the column and put the sum here. +We admire and love a good cook. +There the flood mark is ten inches. +He carved a head from the round block of marble. +She has a smart way of wearing clothes. +The fruit of a fig tree is apple-shaped. +Corn cobs can be used to kindle a fire. +Where were they when the noise started. +The paper box is full of thumb tacks. +Sell your gift to a buyer at a good gain. +The tongs lay beside the ice pail. +The petals fall with the next puff of wind. +Bring your best compass to the third class. +They could laugh although they were sad. +Farmers came in to thresh the oat crop. +The brown house was on fire to the attic. +The lure is used to catch trout and flounder. +Float the soap on top of the bath water. +A blue crane is a tall wading bird. +A fresh start will work such wonders. +The club rented the rink for the fifth night. +After the dance, they went straight home. +The hostess taught the new maid to serve. +He wrote his last novel there at the inn. +Even the worst will beat his low score. +The cement had dried when he moved it. +The loss of the second ship was hard to take. +The fly made its way along the wall. +Do that with a wooden stick. +Live wires should be kept covered. +The large house had hot water taps. +It is hard to erase blue or red ink. +Write at once or you may forget it. +The doorknob was made of bright clean brass. +The wreck occurred by the bank on Main Street. +A pencil with black lead writes best. +Coax a young calf to drink from a bucket. +Schools for ladies teach charm and grace. +The lamp shone with a steady green flame. +They took the axe and the saw to the forest. +The ancient coin was quite dull and worn. +The shaky barn fell with a loud crash. +Jazz and swing fans like fast music. +Rake the rubbish up and then burn it. +Slash the gold cloth into fine ribbons. +Try to have the court decide the case. +They are pushed back each time they attack. +He broke his ties with groups of former friends. +They floated on the raft to sun their white backs. +The map had an X that meant nothing. +Whitings are small fish caught in nets. +Some ads serve to cheat buyers. +Jerk the rope and the bell rings weakly. +A waxed floor makes us lose balance. +Madam, this is the best brand of corn. +On the islands the sea breeze is soft and mild. +The play began as soon as we sat down. +This will lead the world to more sound and fury. +Add salt before you fry the egg. +The rush for funds reached its peak Tuesday. +The birch looked stark white and lonesome. +The box is held by a bright red snapper. +To make pure ice, you freeze water. +The first worm gets snapped early. +Jump the fence and hurry up the bank. +Yell and clap as the curtain slides back. +They are men who walk the middle of the road. +Both brothers wear the same size. +In some form or other we need fun. +The prince ordered his head chopped off. +The houses are built of red clay bricks. +Ducks fly north but lack a compass. +Fruit flavors are used in fizz drinks. +These pills do less good than others. +Canned pears lack full flavor. +The dark pot hung in the front closet. +Carry the pail to the wall and spill it there. +The train brought our hero to the big town. +We are sure that one war is enough. +Gray paint stretched for miles around. +The rude laugh filled the empty room. +High seats are best for football fans. +Tea served from the brown jug is tasty. +A dash of pepper spoils beef stew. +A zestful food is the hot-cross bun. +The horse trotted around the field at a brisk pace. +Find the twin who stole the pearl necklace. +Cut the cord that binds the box tightly. +The red tape bound the smuggled food. +Look in the corner to find the tan shirt. +The cold drizzle will halt the bond drive. +Nine men were hired to dig the ruins. +The junk yard had a mouldy smell. +The flint sputtered and lit a pine torch. +Soak the cloth and drown the sharp odor. +The shelves were bare of both jam or crackers. +A joy to every child is the swan boat. +All sat frozen and watched the screen. +A cloud of dust stung his tender eyes. +To reach the end he needs much courage. +Shape the clay gently into block form. +A ridge on a smooth surface is a bump or flaw. +Hedge apples may stain your hands green. +Quench your thirst, then eat the crackers. +Tight curls get limp on rainy days. +The mute muffled the high tones of the horn. +The gold ring fits only a pierced ear. +The old pan was covered with hard fudge. +Watch the log float in the wide river. +The node on the stalk of wheat grew daily. +The heap of fallen leaves was set on fire. +Write fast if you want to finish early. +His shirt was clean but one button was gone. +The barrel of beer was a brew of malt and hops. +Tin cans are absent from store shelves. +Slide the box into that empty space. +The plant grew large and green in the window. +The beam dropped down on the workmen's head. +Pink clouds floated with the breeze. +She danced like a swan, tall and graceful. +The tube was blown and the tire flat and useless. +It is late morning on the old wall clock. +Let's all join as we sing the last chorus. +The last switch cannot be turned off. +The fight will end in just six minutes. +The store walls were lined with colored frocks. +The peace league met to discuss their plans. +The rise to fame of a person takes luck. +Paper is scarce, so write with much care. +The quick fox jumped on the sleeping cat. +The nozzle of the fire hose was bright brass. +Screw the round cap on as tight as needed. +Time brings us many changes. +The purple tie was ten years old. +Men think and plan and sometimes act. +Fill the ink jar with sticky glue. +He smoke a big pipe with strong contents. +We need grain to keep our mules healthy. +Pack the records in a neat thin case. +The crunch of feet in the snow was the only sound. +The copper bowl shone in the sun's rays. +Boards will warp unless kept dry. +The plush chair leaned against the wall. +Glass will clink when struck by metal. +Bathe and relax in the cool green grass. +Nine rows of soldiers stood in line. +The beach is dry and shallow at low tide. +The idea is to sew both edges straight. +The kitten chased the dog down the street. +Pages bound in cloth make a book. +Try to trace the fine lines of the painting. +Women form less than half of the group. +The zones merge in the central part of town. +A gem in the rough needs work to polish. +Code is used when secrets are sent. +Most of the news is easy for us to hear. +He used the lathe to make brass objects. +The vane on top of the pole revolved in the wind. +Mince pie is a dish served to children. +The clan gathered on each dull night. +Let it burn, it gives us warmth and comfort. +A castle built from sand fails to endure. +A child's wit saved the day for us. +Tack the strip of carpet to the worn floor. +Next Tuesday we must vote. +Pour the stew from the pot into the plate. +Each penny shone like new. +The man went to the woods to gather sticks. +The dirt piles were lines along the road. +The logs fell and tumbled into the clear stream. +Just hoist it up and take it away. +A ripe plum is fit for a king's palate. +Our plans right now are hazy. +Brass rings are sold by these natives. +It takes a good trap to capture a bear. +Feed the white mouse some flower seeds. +The thaw came early and freed the stream. +He took the lead and kept it the whole distance. +The key you designed will fit the lock. +Plead to the council to free the poor thief. +Better hash is made of rare beef. +This plank was made for walking on. +The lake sparkled in the red hot sun. +He crawled with care along the ledge. +Tend the sheep while the dog wanders. +It takes a lot of help to finish these. +Mark the spot with a sign painted red. +Take two shares as a fair profit. +The fur of cats goes by many names. +North winds bring colds and fevers. +He asks no person to vouch for him. +Go now and come here later. +A sash of gold silk will trim her dress. +Soap can wash most dirt away. +That move means the game is over. +He wrote down a long list of items. +A siege will crack the strong defense. +Grape juice and water mix well. +Roads are paved with sticky tar. +Fake stones shine but cost little. +The drip of the rain made a pleasant sound. +Smoke poured out of every crack. +Serve the hot rum to the tired heroes. +Much of the story makes good sense. +The sun came up to light the eastern sky. +Heave the line over the port side. +A lathe cuts and trims any wood. +It's a dense crowd in two distinct ways. +His hip struck the knee of the next player. +The stale smell of old beer lingers. +The desk was firm on the shaky floor. +It takes heat to bring out the odor. +Beef is scarcer than some lamb. +Raise the sail and steer the ship northward. +A cone costs five cents on Mondays. +A pod is what peas always grow in. +Jerk the dart from the cork target. +No cement will hold hard wood. +We now have a new base for shipping. +A list of names is carved around the base. +The sheep were led home by a dog. +Three for a dime, the young peddler cried. +The sense of smell is better than that of touch. +No hardship seemed to keep him sad. +Grace makes up for lack of beauty. +Nudge gently but wake her now. +The news struck doubt into restless minds. +Once we stood beside the shore. +A chink in the wall allowed a draft to blow. +Fasten two pins on each side. +A cold dip restores health and zest. +He takes the oath of office each March. +The sand drifts over the sill of the old house. +The point of the steel pen was bent and twisted. +There is a lag between thought and act. +Seed is needed to plant the spring corn. +Draw the chart with heavy black lines. +The boy owed his pal thirty cents. +The chap slipped into the crowd and was lost. +Hats are worn to tea and not to dinner. +The ramp led up to the wide highway. +Beat the dust from the rug onto the lawn. +Say it slowly but make it ring clear. +The straw nest housed five robins. +Screen the porch with woven straw mats. +This horse will nose his way to the finish. +The dry wax protects the deep scratch. +He picked up the dice for a second roll. +These coins will be needed to pay his debt. +The nag pulled the frail cart along. +Twist the valve and release hot steam. +The vamp of the shoe had a gold buckle. +The smell of burned rags itches my nose. +New pants lack cuffs and pockets. +The marsh will freeze when cold enough. +They slice the sausage thin with a knife. +The bloom of the rose lasts a few days. +A gray mare walked before the colt. +Breakfast buns are fine with a hot drink. +Bottles hold four kinds of rum. +The man wore a feather in his felt hat. +He wheeled the bike past the winding road. +Drop the ashes on the worn old rug. +The desk and both chairs were painted tan. +Throw out the used paper cup and plate. +A clean neck means a neat collar. +The couch cover and hall drapes were blue. +The stems of the tall glasses cracked and broke. +The wall phone rang loud and often. +The clothes dried on a thin wooden rack. +Turn on the lantern which gives us light. +The cleat sank deeply into the soft turf. +The bills were mailed promptly on the tenth of the month. +To have is better than to wait and hope. +The price is fair for a good antique clock. +The music played on while they talked. +Dispense with a vest on a day like this. +The bunch of grapes was pressed into wine. +He sent the figs, but kept the ripe cherries. +The hinge on the door creaked with old age. +The screen before the fire kept in the sparks. +Fly by night, and you waste little time. +Thick glasses helped him read the print. +Birth and death mark the limits of life. +The chair looked strong but had no bottom. +The kite flew wildly in the high wind. +A fur muff is stylish once more. +The tin box held priceless stones. +We need an end of all such matter. +The case was puzzling to the old and wise. +The bright lanterns were gay on the dark lawn. +We don't get much money but we have fun. +The youth drove with zest, but little skill. +Five years he lived with a shaggy dog. +A fence cuts through the corner lot. +The way to save money is not to spend much. +Shut the hatch before the waves push it in. +The odor of spring makes young hearts jump. +Crack the walnut with your sharp side teeth. +He offered proof in the form of a large chart. +Send the stuff in a thick paper bag. +A quart of milk is water for the most part. +They told wild tales to frighten him. +The three story house was built of stone. +In the rear of the ground floor was a large passage. +A man in a blue sweater sat at the desk. +Oats are a food eaten by horse and man. +Their eyelids droop for want of sleep. +A sip of tea revives his tired friend. +There are many ways to do these things. +Tuck the sheet under the edge of the mat. +A force equal to that would move the earth. +We like to see clear weather. +The work of the tailor is seen on each side. +Take a chance and win a china doll. +Shake the dust from your shoes, stranger. +She was kind to sick old people. +The square wooden crate was packed to be shipped. +The dusty bench stood by the stone wall. +We dress to suit the weather of most days. +Smile when you say nasty words. +A bowl of rice is free with chicken stew. +The water in this well is a source of good health. +Take shelter in this tent, but keep still. +That guy is the writer of a few banned books. +The little tales they tell are false. +The door was barred, locked, and bolted as well. +Ripe pears are fit for a queen's table. +A big wet stain was on the round carpet. +The kite dipped and swayed, but stayed aloft. +The pleasant hours fly by much too soon. +The room was crowded with a wild mob. +This strong arm shall shield your honor. +She blushed when he gave her a white orchid. +The beetle droned in the hot June sun. +Press the pedal with your left foot. +Neat plans fail without luck. +The black trunk fell from the landing. +The bank pressed for payment of the debt. +The theft of the pearl pin was kept secret. +Shake hands with this friendly child. +The vast space stretched into the far distance. +A rich farm is rare in this sandy waste. +His wide grin earned many friends. +Flax makes a fine brand of paper. +Hurdle the pit with the aid of a long pole. +A strong bid may scare your partner stiff. +Even a just cause needs power to win. +Peep under the tent and see the clowns. +The leaf drifts along with a slow spin. +Cheap clothes are flashy but don't last. +A thing of small note can cause despair. +Flood the mails with requests for this book. +A thick coat of black paint covered all. +The pencil was cut to be sharp at both ends. +Those last words were a strong statement. +He wrote his name boldly at the top of the sheet. +Dill pickles are sour but taste fine. +Down that road is the way to the grain farmer. +Either mud or dust are found at all times. +The best method is to fix it in place with clips. +If you mumble your speech will be lost. +At night the alarm roused him from a deep sleep. +Read just what the meter says. +Fill your pack with bright trinkets for the poor. +The small red neon lamp went out. +Clams are small, round, soft, and tasty. +The fan whirled its round blades softly. +The line where the edges join was clean. +Breathe deep and smell the piny air. +It matters not if he reads these words or those. +A brown leather bag hung from its strap. +A toad and a frog are hard to tell apart. +A white silk jacket goes with any shoes. +A break in the dam almost caused a flood. +Paint the sockets in the wall dull green. +The child crawled into the dense grass. +Bribes fail where honest men work. +Trample the spark, else the flames will spread. +The hilt of the sword was carved with fine designs. +A round hole was drilled through the thin board. +Footprints showed the path he took up the beach. +She was waiting at my front lawn. +A vent near the edge brought in fresh air. +Prod the old mule with a crooked stick. +It is a band of steel three inches wide. +The pipe ran almost the length of the ditch. +It was hidden from sight by a mass of leaves and shrubs. +The weight of the package was seen on the high scale. +Wake and rise, and step into the green outdoors. +The green light in the brown box flickered. +The brass tube circled the high wall. +The lobes of her ears were pierced to hold rings. +Hold the hammer near the end to drive the nail. +Next Sunday is the twelfth of the month. +Every word and phrase he speaks is true. +He put his last cartridge into the gun and fired. +They took their kids from the public school. +Drive the screw straight into the wood. +Keep the hatch tight and the watch constant. +Sever the twine with a quick snip of the knife. +Paper will dry out when wet. +Slide the catch back and open the desk. +Help the weak to preserve their strength. +A sullen smile gets few friends. +Stop whistling and watch the boys march. +Jerk the cord, and out tumbles the gold. +Slide the tray across the glass top. +The cloud moved in a stately way and was gone. +Light maple makes for a swell room. +Set the piece here and say nothing. +Dull stories make her laugh. +A stiff cord will do to fasten your shoe. +Get the trust fund to the bank early. +Choose between the high road and the low. +A plea for funds seems to come again. +He lent his coat to the tall gaunt stranger. +There is a strong chance it will happen once more. +The duke left the park in a silver coach. +Greet the new guests and leave quickly. +When the frost has come it is time for turkey. +Sweet words work better than fierce. +A thin stripe runs down the middle. +A six comes up more often than a ten. +Lush fern grow on the lofty rocks. +The ram scared the school children off. +The team with the best timing looks good. +The farmer swapped his horse for a brown ox. +Sit on the perch and tell the others what to do. +A steep trail is painful for our feet. +The early phase of life moves fast. +Green moss grows on the northern side. +Tea in thin china has a sweet taste. +Pitch the straw through the door of the stable. +The latch on the back gate needed a nail. +The goose was brought straight from the old market. +The sink is the thing in which we pile dishes. +A whiff of it will cure the most stubborn cold. +The facts don't always show who is right. +She flaps her cape as she parades the street. +The loss of the cruiser was a blow to the fleet. +Loop the braid to the left and then over. +Plead with the lawyer to drop the lost cause. +Calves thrive on tender spring grass. +Post no bills on this office wall. +Tear a thin sheet from the yellow pad. +A cruise in warm waters in a sleek yacht is fun. +A streak of color ran down the left edge. +It was done before the boy could see it. +Crouch before you jump or miss the mark. +Pack the kits and don't forget the salt. +The square peg will settle in the round hole. +Fine soap saves tender skin. +Poached eggs and tea must suffice. +Bad nerves are jangled by a door slam. +Ship maps are different from those for planes. +Dimes showered down from all sides. +They sang the same tunes at each party. +The sky in the west is tinged with orange red. +The pods of peas ferment in bare fields. +The horse balked and threw the tall rider. +The hitch between the horse and cart broke. +Pile the coal high in the shed corner. +A gold vase is both rare and costly. +The knife was hung inside its bright sheath. +The rarest spice comes from the far East. +The roof should be tilted at a sharp slant. +A smatter of French is worse than none. +The mule trod the treadmill day and night. +The aim of the contest is to raise a great fund. +To send it now in large amounts is bad. +There is a fine hard tang in salty air. +Cod is the main business of the north shore. +The slab was hewn from heavy blocks of slate. +Dunk the stale biscuits into strong drink. +Hang tinsel from both branches. +Cap the jar with a tight brass cover. +The poor boy missed the boat again. +Be sure to set the lamp firmly in the hole. +Pick a card and slip it under the pack. +A round mat will cover the dull spot. +The first part of the plan needs changing. +A good book informs of what we ought to know. +The mail comes in three batches per day. +You cannot brew tea in a cold pot. +Dots of light betrayed the black cat. +Put the chart on the mantel and tack it down. +The night shift men rate extra pay. +The red paper brightened the dim stage. +See the player scoot to third base. +Slide the bill between the two leaves. +Many hands help get the job done. +We don't like to admit our small faults. +No doubt about the way the wind blows. +Dig deep in the earth for pirate's gold. +The steady drip is worse than a drenching rain. +A flat pack takes less luggage space. +Green ice frosted the punch bowl. +A stuffed chair slipped from the moving van. +The stitch will serve but needs to be shortened. +A thin book fits in the side pocket. +The gloss on top made it unfit to read. +The hail pattered on the burnt brown grass. +Seven seals were stamped on great sheets. +Our troops are set to strike heavy blows. +The store was jammed before the sale could start. +It was a bad error on the part of the new judge. +One step more and the board will collapse. +Take the match and strike it against your shoe. +The pot boiled, but the contents failed to jell. +The baby puts his right foot in his mouth. +The bombs left most of the town in ruins. +Stop and stare at the hard working man. +The streets are narrow and full of sharp turns. +The pup jerked the leash as he saw a feline shape. +Open your book to the first page. +Fish evade the net and swim off. +Dip the pail once and let it settle. +Will you please answer that phone. +The big red apple fell to the ground. +The curtain rose and the show was on. +The young prince became heir to the throne. +He sent the boy on a short errand. +Leave now and you will arrive on time. +The corner store was robbed last night. +A gold ring will please most any girl. +The long journey home took a year. +She saw a cat in the neighbor's house. +A pink shell was found on the sandy beach. +Small children came to see him. +The grass and bushes were wet with dew. +The blind man counted his old coins. +A severe storm tore down the barn. +She called his name many times. +When you hear the bell, come quickly. \ No newline at end of file diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index 3ecd5bd..453207a 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -1,3 +1,5 @@ +import html +import random import time from pathlib import Path @@ -5,7 +7,7 @@ import gradio as gr import torch from extensions.silero_tts import tts_preprocessor -from modules import chat, shared +from modules import chat, shared, ui_chat from modules.utils import gradio torch._C._jit_set_profiling_mode(False) @@ -26,7 +28,25 @@ params = { } current_params = params.copy() -voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] + +voices_en = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] +voices_es = ["es_0", "es_1", "es_2"] +voices_fr = ["fr_0", "fr_1", "fr_2", "fr_3", "fr_4", "fr_5"] +voices_de = ["bernd_ungerer", "eva_k", "friedrich", "hokuspokus", "karlsson"] +voices_ru = ["aidar", "baya", "kseniya", "xenia"] +voices_ua = ["mykyta"] +voices_uz = ["dilnavoz"] + +languages = { + "en": {"label": "English", "voices": voices_en, "default_voice": "en_56", "model_id": "v3_en"}, + "es": {"label": "Español", "voices": voices_es, "default_voice": "es_0", "model_id": "v3_es"}, + "fr": {"label": "Français", "voices": voices_fr, "default_voice": "fr_0", "model_id": "v3_fr"}, + "de": {"label": "Deutsch", "voices": voices_de, "default_voice": "eva_k", "model_id": "v3_de"}, + "ru": {"label": "русский", "voices": voices_ru, "default_voice": "aidar", "model_id": "ru_v3"}, + "ua": {"label": "українська", "voices": voices_ua, "default_voice": "mykyta", "model_id": "v3_ua"}, + "uz": {"label": "Oʻzbekcha", "voices": voices_uz, "default_voice": "dilnavoz", "model_id": "v3_uz"}, +} + voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high'] voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast'] @@ -106,6 +126,7 @@ def history_modifier(history): def output_modifier(string, state): global model, current_params, streaming_state + for i in params: if params[i] != current_params[i]: model = load_model() @@ -116,7 +137,7 @@ def output_modifier(string, state): return string original_string = string - string = tts_preprocessor.preprocess(string) + string = tts_preprocessor.preprocess(html.unescape(string)) if string == '': string = '*Empty reply, try regenerating*' @@ -140,6 +161,42 @@ def setup(): model = load_model() +def random_sentence(): + with open(Path("extensions/silero_tts/harvard_sentences.txt")) as f: + return random.choice(list(f)) + + +def voice_preview(preview_text): + global model, current_params, streaming_state + + for i in params: + if params[i] != current_params[i]: + model = load_model() + current_params = params.copy() + break + + string = tts_preprocessor.preprocess(preview_text or random_sentence()) + + output_file = Path('extensions/silero_tts/outputs/voice_preview.wav') + prosody = f"" + silero_input = f'{prosody}{xmlesc(string)}' + model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) + + return f'' + + +def language_change(lang): + global params + lang_code = list(languages.keys())[lang] + params.update({"language": lang_code, "speaker": languages[lang_code]["default_voice"], "model_id": languages[lang_code]["model_id"]}) + return gr.update(choices=languages[lang_code]["voices"], value=languages[lang_code]["default_voice"]) + + +def custom_css(): + path_to_css = Path(__file__).parent.resolve() / 'style.css' + return open(path_to_css, 'r').read() + + def ui(): # Gradio elements with gr.Accordion("Silero TTS"): @@ -148,40 +205,50 @@ def ui(): autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically') show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player') - voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') + + with gr.Row(): + language = gr.Dropdown(value=languages[params['language']]["label"], choices=[v["label"] for _, v in languages.items()], label='Language', type="index") + voice = gr.Dropdown(value=params['speaker'], choices=voices_en, label='TTS voice') with gr.Row(): v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch') v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed') + with gr.Row(): + preview_text = gr.Text(show_label=False, placeholder="Preview text", elem_id="silero_preview_text") + preview_play = gr.Button("Preview") + preview_audio = gr.HTML(visible=False) + with gr.Row(): convert = gr.Button('Permanently replace audios with the message texts') convert_cancel = gr.Button('Cancel', visible=False) convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False) - gr.Markdown('[Click here for Silero audio samples](https://oobabooga.github.io/silero-samples/index.html)') + # Convert history with confirmation + convert_arr = [convert_confirm, convert, convert_cancel] + convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) + convert_confirm.click( + lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( + remove_tts_from_history, gradio('history'), gradio('history')).then( + chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( + chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display')) - if shared.is_chat(): - # Convert history with confirmation - convert_arr = [convert_confirm, convert, convert_cancel] - convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) - convert_confirm.click( - lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( - remove_tts_from_history, gradio('history'), gradio('history')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) + convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) - convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) - - # Toggle message text in history - show_text.change( - lambda x: params.update({"show_text": x}), show_text, None).then( - toggle_text_in_history, gradio('history'), gradio('history')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) + # Toggle message text in history + show_text.change( + lambda x: params.update({"show_text": x}), show_text, None).then( + toggle_text_in_history, gradio('history'), gradio('history')).then( + chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( + chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display')) # Event functions to update the parameters in the backend activate.change(lambda x: params.update({"activate": x}), activate, None) autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None) + language.change(language_change, language, voice, show_progress=False) voice.change(lambda x: params.update({"speaker": x}), voice, None) v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None) v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None) + + # Play preview + preview_text.submit(voice_preview, preview_text, preview_audio) + preview_play.click(voice_preview, preview_text, preview_audio) diff --git a/extensions/silero_tts/style.css b/extensions/silero_tts/style.css new file mode 100644 index 0000000..2ab7aef --- /dev/null +++ b/extensions/silero_tts/style.css @@ -0,0 +1,8 @@ +.SDAP .hires_opts input[type="number"] { + width: 6em !important; +} + +/* silero_tts preview */ +.form:has(> #silero_preview_text) { + min-width: 75% +} diff --git a/extensions/superbooga/download_urls.py b/extensions/superbooga/download_urls.py index efe300d..424a988 100644 --- a/extensions/superbooga/download_urls.py +++ b/extensions/superbooga/download_urls.py @@ -4,7 +4,10 @@ import requests def download_single(url): - response = requests.get(url, timeout=5) + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' + } + response = requests.get(url, headers=headers, timeout=5) if response.status_code == 200: return response.content else: diff --git a/extensions/superbooga/requirements.txt b/extensions/superbooga/requirements.txt index dd2cbde..73a6007 100644 --- a/extensions/superbooga/requirements.txt +++ b/extensions/superbooga/requirements.txt @@ -1,4 +1,6 @@ beautifulsoup4==4.12.2 chromadb==0.3.18 +pandas==2.0.3 posthog==2.4.2 sentence_transformers==2.2.2 +lxml diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index c0d3f8e..06fe8ad 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -4,7 +4,7 @@ import textwrap import gradio as gr from bs4 import BeautifulSoup -from modules import chat, shared +from modules import chat from modules.logging_colors import logger from .chromadb import add_chunks_to_collector, make_collector @@ -69,7 +69,7 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads) cumulative += 'Processing the HTML sources...' yield cumulative for content in contents: - soup = BeautifulSoup(content, features="html.parser") + soup = BeautifulSoup(content, features="lxml") for script in soup(["script", "style"]): script.extract() @@ -96,7 +96,8 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight): def custom_generate_chat_prompt(user_input, state, **kwargs): global chat_collector - history = state['history'] + # get history as being modified when using regenerate. + history = kwargs['history'] if state['mode'] == 'instruct': results = collector.get_sorted(user_input, n_results=params['chunk_count']) @@ -113,7 +114,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): if len(history['internal']) > params['chunk_count'] and user_input != '': chunks = [] hist_size = len(history['internal']) - for i in range(hist_size-1): + for i in range(hist_size - 1): chunks.append(make_single_exchange(i)) add_chunks_to_collector(chunks, chat_collector) @@ -142,8 +143,8 @@ def remove_special_tokens(string): return re.sub(pattern, '', string) -def input_modifier(string): - if shared.is_chat(): +def input_modifier(string, state, is_chat=False): + if is_chat: return string # Find the user input diff --git a/extensions/whisper_stt/script.py b/extensions/whisper_stt/script.py index 44a9ac8..cdc5568 100644 --- a/extensions/whisper_stt/script.py +++ b/extensions/whisper_stt/script.py @@ -16,7 +16,16 @@ params = { } -def do_stt(audio,whipser_model,whipser_language): +def chat_input_modifier(text, visible_text, state): + global input_hijack + if input_hijack['state']: + input_hijack['state'] = False + return input_hijack['value'] + else: + return text, visible_text + + +def do_stt(audio, whipser_model, whipser_language): transcription = "" r = sr.Recognizer() @@ -33,10 +42,10 @@ def do_stt(audio,whipser_model,whipser_language): return transcription -def auto_transcribe(audio, auto_submit,whipser_model,whipser_language): +def auto_transcribe(audio, auto_submit, whipser_model, whipser_language): if audio is None: return "", "" - transcription = do_stt(audio,whipser_model,whipser_language) + transcription = do_stt(audio, whipser_model, whipser_language) if auto_submit: input_hijack.update({"state": True, "value": [transcription, transcription]}) @@ -50,12 +59,13 @@ def ui(): with gr.Row(): with gr.Accordion("Settings", open=False): auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit']) - whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'],choices=["tiny.en","base.en", "small.en","medium.en","tiny","base","small","medium","large"]) - whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'],choices=["chinese","german","spanish","russian","korean","french","japanese","portuguese","turkish","polish","catalan","dutch","arabic","swedish","italian","indonesian","hindi","finnish","vietnamese","hebrew","ukrainian","greek","malay","czech","romanian","danish","hungarian","tamil","norwegian","thai","urdu","croatian","bulgarian","lithuanian","latin","maori","malayalam","welsh","slovak","telugu","persian","latvian","bengali","serbian","azerbaijani","slovenian","kannada","estonian","macedonian","breton","basque","icelandic","armenian","nepali","mongolian","bosnian","kazakh","albanian","swahili","galician","marathi","punjabi","sinhala","khmer","shona","yoruba","somali","afrikaans","occitan","georgian","belarusian","tajik","sindhi","gujarati","amharic","yiddish","lao","uzbek","faroese","haitian creole","pashto","turkmen","nynorsk","maltese","sanskrit","luxembourgish","myanmar","tibetan","tagalog","malagasy","assamese","tatar","hawaiian","lingala","hausa","bashkir","javanese","sundanese"]) + whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"]) + whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"]) audio.change( - auto_transcribe, [audio, auto_submit,whipser_model,whipser_language], [shared.gradio['textbox'], audio]).then( + auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then( None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}") + whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None) whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None) auto_submit.change(lambda x: params.update({"auto_submit": x}), auto_submit, None) diff --git a/instruction-templates/Airoboros-v1.2.yaml b/instruction-templates/Airoboros-v1.2.yaml new file mode 100644 index 0000000..7f1bfed --- /dev/null +++ b/instruction-templates/Airoboros-v1.2.yaml @@ -0,0 +1,4 @@ +user: "USER:" +bot: "ASSISTANT:" +turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n" +context: "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input.\n" diff --git a/characters/instruction-following/Alpaca.yaml b/instruction-templates/Alpaca.yaml similarity index 100% rename from characters/instruction-following/Alpaca.yaml rename to instruction-templates/Alpaca.yaml diff --git a/characters/instruction-following/Bactrian.yaml b/instruction-templates/Bactrian.yaml similarity index 100% rename from characters/instruction-following/Bactrian.yaml rename to instruction-templates/Bactrian.yaml diff --git a/instruction-templates/Baichuan Chat.yaml b/instruction-templates/Baichuan Chat.yaml new file mode 100644 index 0000000..15adca1 --- /dev/null +++ b/instruction-templates/Baichuan Chat.yaml @@ -0,0 +1,4 @@ +user: "" +bot: "" +turn_template: "<|user|><|user-message|><|bot|><|bot-message|>" +context: "" diff --git a/characters/instruction-following/Baize.yaml b/instruction-templates/Baize.yaml similarity index 100% rename from characters/instruction-following/Baize.yaml rename to instruction-templates/Baize.yaml diff --git a/characters/instruction-following/Bluemoon.yaml b/instruction-templates/Bluemoon.yaml similarity index 100% rename from characters/instruction-following/Bluemoon.yaml rename to instruction-templates/Bluemoon.yaml diff --git a/characters/instruction-following/ChatGLM.yaml b/instruction-templates/ChatGLM.yaml similarity index 100% rename from characters/instruction-following/ChatGLM.yaml rename to instruction-templates/ChatGLM.yaml diff --git a/characters/instruction-following/Chinese-Vicuna-Chat.yaml b/instruction-templates/Chinese-Vicuna-Chat.yaml similarity index 100% rename from characters/instruction-following/Chinese-Vicuna-Chat.yaml rename to instruction-templates/Chinese-Vicuna-Chat.yaml diff --git a/characters/instruction-following/Galactica Cite.yaml b/instruction-templates/Galactica Cite.yaml similarity index 100% rename from characters/instruction-following/Galactica Cite.yaml rename to instruction-templates/Galactica Cite.yaml diff --git a/characters/instruction-following/Galactica Finetuned.yaml b/instruction-templates/Galactica Finetuned.yaml similarity index 100% rename from characters/instruction-following/Galactica Finetuned.yaml rename to instruction-templates/Galactica Finetuned.yaml diff --git a/characters/instruction-following/Galactica Q.yaml b/instruction-templates/Galactica Q.yaml similarity index 100% rename from characters/instruction-following/Galactica Q.yaml rename to instruction-templates/Galactica Q.yaml diff --git a/characters/instruction-following/Galactica Summary.yaml b/instruction-templates/Galactica Summary.yaml similarity index 100% rename from characters/instruction-following/Galactica Summary.yaml rename to instruction-templates/Galactica Summary.yaml diff --git a/characters/instruction-following/Galactica Work.yaml b/instruction-templates/Galactica Work.yaml similarity index 100% rename from characters/instruction-following/Galactica Work.yaml rename to instruction-templates/Galactica Work.yaml diff --git a/characters/instruction-following/Galactica v2.yaml b/instruction-templates/Galactica v2.yaml similarity index 100% rename from characters/instruction-following/Galactica v2.yaml rename to instruction-templates/Galactica v2.yaml diff --git a/characters/instruction-following/Galactica.yaml b/instruction-templates/Galactica.yaml similarity index 100% rename from characters/instruction-following/Galactica.yaml rename to instruction-templates/Galactica.yaml diff --git a/characters/instruction-following/Gorilla.yaml b/instruction-templates/Gorilla.yaml similarity index 100% rename from characters/instruction-following/Gorilla.yaml rename to instruction-templates/Gorilla.yaml diff --git a/characters/instruction-following/Guanaco non-chat.yaml b/instruction-templates/Guanaco non-chat.yaml similarity index 100% rename from characters/instruction-following/Guanaco non-chat.yaml rename to instruction-templates/Guanaco non-chat.yaml diff --git a/characters/instruction-following/Guanaco-QLoRA.yaml b/instruction-templates/Guanaco-QLoRA.yaml similarity index 100% rename from characters/instruction-following/Guanaco-QLoRA.yaml rename to instruction-templates/Guanaco-QLoRA.yaml diff --git a/characters/instruction-following/Guanaco.yaml b/instruction-templates/Guanaco.yaml similarity index 100% rename from characters/instruction-following/Guanaco.yaml rename to instruction-templates/Guanaco.yaml diff --git a/characters/instruction-following/H2O-human_bot.yaml b/instruction-templates/H2O-human_bot.yaml similarity index 100% rename from characters/instruction-following/H2O-human_bot.yaml rename to instruction-templates/H2O-human_bot.yaml diff --git a/characters/instruction-following/H2O-prompt_answer.yaml b/instruction-templates/H2O-prompt_answer.yaml similarity index 100% rename from characters/instruction-following/H2O-prompt_answer.yaml rename to instruction-templates/H2O-prompt_answer.yaml diff --git a/characters/instruction-following/Hippogriff.yaml b/instruction-templates/Hippogriff.yaml similarity index 100% rename from characters/instruction-following/Hippogriff.yaml rename to instruction-templates/Hippogriff.yaml diff --git a/characters/instruction-following/INCITE-Chat.yaml b/instruction-templates/INCITE-Chat.yaml similarity index 100% rename from characters/instruction-following/INCITE-Chat.yaml rename to instruction-templates/INCITE-Chat.yaml diff --git a/characters/instruction-following/INCITE-Instruct.yaml b/instruction-templates/INCITE-Instruct.yaml similarity index 100% rename from characters/instruction-following/INCITE-Instruct.yaml rename to instruction-templates/INCITE-Instruct.yaml diff --git a/characters/instruction-following/KoAlpaca.yaml b/instruction-templates/KoAlpaca.yaml similarity index 100% rename from characters/instruction-following/KoAlpaca.yaml rename to instruction-templates/KoAlpaca.yaml diff --git a/characters/instruction-following/Koala.yaml b/instruction-templates/Koala.yaml similarity index 100% rename from characters/instruction-following/Koala.yaml rename to instruction-templates/Koala.yaml diff --git a/characters/instruction-following/LLaVA.yaml b/instruction-templates/LLaVA.yaml similarity index 100% rename from characters/instruction-following/LLaVA.yaml rename to instruction-templates/LLaVA.yaml diff --git a/instruction-templates/Llama-v2.yaml b/instruction-templates/Llama-v2.yaml new file mode 100644 index 0000000..d259dd3 --- /dev/null +++ b/instruction-templates/Llama-v2.yaml @@ -0,0 +1,4 @@ +user: "" +bot: "" +turn_template: "<|user|><|user-message|> [/INST] <|bot|><|bot-message|> [INST] " +context: "[INST] <>\nAnswer the questions.\n<>\n\n" diff --git a/characters/instruction-following/MOSS.yaml b/instruction-templates/MOSS.yaml similarity index 100% rename from characters/instruction-following/MOSS.yaml rename to instruction-templates/MOSS.yaml diff --git a/characters/instruction-following/MPT-Chat.yaml b/instruction-templates/MPT-Chat.yaml similarity index 100% rename from characters/instruction-following/MPT-Chat.yaml rename to instruction-templates/MPT-Chat.yaml diff --git a/characters/instruction-following/Manticore Chat.yaml b/instruction-templates/Manticore Chat.yaml similarity index 100% rename from characters/instruction-following/Manticore Chat.yaml rename to instruction-templates/Manticore Chat.yaml diff --git a/characters/instruction-following/Metharme.yaml b/instruction-templates/Metharme.yaml similarity index 100% rename from characters/instruction-following/Metharme.yaml rename to instruction-templates/Metharme.yaml diff --git a/characters/instruction-following/Minotaur.yaml b/instruction-templates/Minotaur.yaml similarity index 100% rename from characters/instruction-following/Minotaur.yaml rename to instruction-templates/Minotaur.yaml diff --git a/instruction-templates/NewHope.yaml b/instruction-templates/NewHope.yaml new file mode 100644 index 0000000..d9a72f6 --- /dev/null +++ b/instruction-templates/NewHope.yaml @@ -0,0 +1,4 @@ +user: "### Instruction:" +bot: "### Response:" +turn_template: "<|user|>\n<|user-message|>\n\n<|bot|>\n<|bot-message|> " +context: " " diff --git a/characters/instruction-following/Open Assistant.yaml b/instruction-templates/Open Assistant.yaml similarity index 100% rename from characters/instruction-following/Open Assistant.yaml rename to instruction-templates/Open Assistant.yaml diff --git a/characters/instruction-following/OpenBuddy.yaml b/instruction-templates/OpenBuddy.yaml similarity index 100% rename from characters/instruction-following/OpenBuddy.yaml rename to instruction-templates/OpenBuddy.yaml diff --git a/instruction-templates/OpenChat.yaml b/instruction-templates/OpenChat.yaml new file mode 100644 index 0000000..3b84c22 --- /dev/null +++ b/instruction-templates/OpenChat.yaml @@ -0,0 +1,4 @@ +user: "GPT4 User:" +bot: "GPT4 Assistant:" +turn_template: "<|user|> <|user-message|><|end_of_turn|><|bot|> <|bot-message|><|end_of_turn|>" +context: "" diff --git a/instruction-templates/OpenOrca-Platypus2.yaml b/instruction-templates/OpenOrca-Platypus2.yaml new file mode 100644 index 0000000..6cac004 --- /dev/null +++ b/instruction-templates/OpenOrca-Platypus2.yaml @@ -0,0 +1,4 @@ +user: "### Instruction:" +bot: "### Response:" +turn_template: "<|user|> <|user-message|>\n\n<|bot|> <|bot-message|>\n\n" +context: "" diff --git a/characters/instruction-following/Orca Mini.yaml b/instruction-templates/Orca Mini.yaml similarity index 100% rename from characters/instruction-following/Orca Mini.yaml rename to instruction-templates/Orca Mini.yaml diff --git a/characters/instruction-following/RWKV-Raven.yaml b/instruction-templates/RWKV-Raven.yaml similarity index 100% rename from characters/instruction-following/RWKV-Raven.yaml rename to instruction-templates/RWKV-Raven.yaml diff --git a/characters/instruction-following/Samantha.yaml b/instruction-templates/Samantha.yaml similarity index 100% rename from characters/instruction-following/Samantha.yaml rename to instruction-templates/Samantha.yaml diff --git a/instruction-templates/StableBeluga2.yaml b/instruction-templates/StableBeluga2.yaml new file mode 100644 index 0000000..cd5675f --- /dev/null +++ b/instruction-templates/StableBeluga2.yaml @@ -0,0 +1,4 @@ +user: "### User:" +bot: "### Assistant:" +turn_template: "<|user|>\n<|user-message|>\n\n<|bot|>\n<|bot-message|>\n\n" +context: "### System:\nThis is a system prompt, please behave and help the user.\n\n" diff --git a/characters/instruction-following/StableLM.yaml b/instruction-templates/StableLM.yaml similarity index 100% rename from characters/instruction-following/StableLM.yaml rename to instruction-templates/StableLM.yaml diff --git a/characters/instruction-following/StableVicuna.yaml b/instruction-templates/StableVicuna.yaml similarity index 100% rename from characters/instruction-following/StableVicuna.yaml rename to instruction-templates/StableVicuna.yaml diff --git a/characters/instruction-following/Starchat-Beta.yaml b/instruction-templates/Starchat-Beta.yaml similarity index 100% rename from characters/instruction-following/Starchat-Beta.yaml rename to instruction-templates/Starchat-Beta.yaml diff --git a/characters/instruction-following/Tulu.yaml b/instruction-templates/Tulu.yaml similarity index 100% rename from characters/instruction-following/Tulu.yaml rename to instruction-templates/Tulu.yaml diff --git a/characters/instruction-following/Vicuna-v0.yaml b/instruction-templates/Vicuna-v0.yaml similarity index 100% rename from characters/instruction-following/Vicuna-v0.yaml rename to instruction-templates/Vicuna-v0.yaml diff --git a/characters/instruction-following/Vicuna-v1.1.yaml b/instruction-templates/Vicuna-v1.1.yaml similarity index 100% rename from characters/instruction-following/Vicuna-v1.1.yaml rename to instruction-templates/Vicuna-v1.1.yaml diff --git a/characters/instruction-following/Vigogne-Chat.yaml b/instruction-templates/Vigogne-Chat.yaml similarity index 100% rename from characters/instruction-following/Vigogne-Chat.yaml rename to instruction-templates/Vigogne-Chat.yaml diff --git a/characters/instruction-following/Vigogne-Instruct.yaml b/instruction-templates/Vigogne-Instruct.yaml similarity index 100% rename from characters/instruction-following/Vigogne-Instruct.yaml rename to instruction-templates/Vigogne-Instruct.yaml diff --git a/characters/instruction-following/Wizard-Mega ShareGPT.yaml b/instruction-templates/Wizard-Mega ShareGPT.yaml similarity index 100% rename from characters/instruction-following/Wizard-Mega ShareGPT.yaml rename to instruction-templates/Wizard-Mega ShareGPT.yaml diff --git a/characters/instruction-following/Wizard-Mega WizardLM.yaml b/instruction-templates/Wizard-Mega WizardLM.yaml similarity index 100% rename from characters/instruction-following/Wizard-Mega WizardLM.yaml rename to instruction-templates/Wizard-Mega WizardLM.yaml diff --git a/characters/instruction-following/Wizard-Mega.yaml b/instruction-templates/Wizard-Mega.yaml similarity index 100% rename from characters/instruction-following/Wizard-Mega.yaml rename to instruction-templates/Wizard-Mega.yaml diff --git a/characters/instruction-following/Ziya.yaml b/instruction-templates/Ziya.yaml similarity index 100% rename from characters/instruction-following/Ziya.yaml rename to instruction-templates/Ziya.yaml diff --git a/js/main.js b/js/main.js new file mode 100644 index 0000000..2155794 --- /dev/null +++ b/js/main.js @@ -0,0 +1,330 @@ +let main_parent = document.getElementById('chat-tab').parentNode; +let extensions = document.getElementById('extensions'); + +main_parent.childNodes[0].classList.add("header_bar"); +main_parent.style = "padding: 0; margin: 0"; +main_parent.parentNode.style = "gap: 0"; +main_parent.parentNode.parentNode.style = "padding: 0"; + +document.querySelector('.header_bar').addEventListener('click', function(event) { + if (event.target.tagName === 'BUTTON') { + const buttonText = event.target.textContent.trim(); + + let chat_visible = (buttonText == 'Chat'); + let default_visible = (buttonText == 'Default'); + let notebook_visible = (buttonText == 'Notebook'); + + // Check if one of the generation tabs is visible + if (chat_visible || notebook_visible || default_visible) { + extensions.style.display = 'flex'; + if (chat_visible) { + extensions.style.maxWidth = "880px"; + extensions.style.padding = "0px"; + } else { + extensions.style.maxWidth = "none"; + extensions.style.padding = "15px"; + } + } else { + extensions.style.display = 'none'; + } + } +}); + +//------------------------------------------------ +// Keyboard shortcuts +//------------------------------------------------ +document.addEventListener("keydown", function(event) { + + // Stop generation on Esc pressed + if (event.key === "Escape") { + // Find the element with id 'stop' and click it + var stopButton = document.getElementById("stop"); + if (stopButton) { + stopButton.click(); + } + } + + // Show chat controls on Ctrl + S + else if (event.ctrlKey && event.key == "s") { + event.preventDefault(); + + var showControlsElement = document.getElementById('show-controls'); + if (showControlsElement && showControlsElement.childNodes.length >= 4) { + showControlsElement.childNodes[3].click(); + + var arr = document.getElementById('chat-input').childNodes[2].childNodes; + arr[arr.length - 1].focus(); + } + } + + // Regenerate on Ctrl + Enter + else if (event.ctrlKey && event.key === 'Enter') { + event.preventDefault(); + document.getElementById('Regenerate').click(); + } + + // Continue on Alt + Enter + else if (event.altKey && event.key === 'Enter') { + event.preventDefault(); + document.getElementById('Continue').click(); + } + + // Remove last on Ctrl + Shift + Backspace + else if (event.ctrlKey && event.shiftKey && event.key === 'Backspace') { + event.preventDefault(); + document.getElementById('Remove-last').click(); + } + + // Copy last on Ctrl + Shift + K + else if (event.ctrlKey && event.shiftKey && event.key === 'K') { + event.preventDefault(); + document.getElementById('Copy-last').click(); + } + + // Replace last on Ctrl + Shift + L + else if (event.ctrlKey && event.shiftKey && event.key === 'L') { + event.preventDefault(); + document.getElementById('Replace-last').click(); + } + + // Impersonate on Ctrl + Shift + M + else if (event.ctrlKey && event.shiftKey && event.key === 'M') { + event.preventDefault(); + document.getElementById('Impersonate').click(); + } + +}); + +//------------------------------------------------ +// Position the chat typing dots +//------------------------------------------------ +typing = document.getElementById('typing-container'); +typingParent = typing.parentNode; +typingSibling = typing.previousElementSibling; +typingSibling.insertBefore(typing, typingSibling.childNodes[2]); + +//------------------------------------------------ +// Chat scrolling +//------------------------------------------------ +const targetElement = document.getElementById('chat').parentNode.parentNode.parentNode; +targetElement.classList.add('pretty_scrollbar'); +targetElement.classList.add('chat-parent'); +let isScrolled = false; + +targetElement.addEventListener('scroll', function() { + let diff = targetElement.scrollHeight - targetElement.clientHeight; + if(Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0) { + isScrolled = false; + } else { + isScrolled = true; + } +}); + +// Create a MutationObserver instance +const observer = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if(!isScrolled) { + targetElement.scrollTop = targetElement.scrollHeight; + } + + const firstChild = targetElement.children[0]; + if (firstChild.classList.contains('generating')) { + typing.parentNode.classList.add('visible-dots'); + document.getElementById('stop').style.display = 'flex'; + document.getElementById('Generate').style.display = 'none'; + } else { + typing.parentNode.classList.remove('visible-dots'); + document.getElementById('stop').style.display = 'none'; + document.getElementById('Generate').style.display = 'flex'; + } + + }); +}); + +// Configure the observer to watch for changes in the subtree and attributes +const config = { + childList: true, + subtree: true, + characterData: true, + attributeOldValue: true, + characterDataOldValue: true +}; + +// Start observing the target element +observer.observe(targetElement, config); + +//------------------------------------------------ +// Notebook box scrolling +//------------------------------------------------ +const notebookElement = document.querySelector('#textbox-notebook textarea'); +let notebookScrolled = false; + +notebookElement.addEventListener('scroll', function() { + let diff = notebookElement.scrollHeight - notebookElement.clientHeight; + if(Math.abs(notebookElement.scrollTop - diff) <= 10 || diff == 0) { + notebookScrolled = false; + } else { + notebookScrolled = true; + } +}); + +const notebookObserver = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if(!notebookScrolled) { + notebookElement.scrollTop = notebookElement.scrollHeight; + } + }); +}); + +notebookObserver.observe(notebookElement.parentNode.parentNode.parentNode, config); + +//------------------------------------------------ +// Default box scrolling +//------------------------------------------------ +const defaultElement = document.querySelector('#textbox-default textarea'); +let defaultScrolled = false; + +defaultElement.addEventListener('scroll', function() { + let diff = defaultElement.scrollHeight - defaultElement.clientHeight; + if(Math.abs(defaultElement.scrollTop - diff) <= 10 || diff == 0) { + defaultScrolled = false; + } else { + defaultScrolled = true; + } +}); + +const defaultObserver = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if(!defaultScrolled) { + defaultElement.scrollTop = defaultElement.scrollHeight; + } + }); +}); + +defaultObserver.observe(defaultElement.parentNode.parentNode.parentNode, config); + +//------------------------------------------------ +// Add some scrollbars +//------------------------------------------------ +const textareaElements = document.querySelectorAll('.add_scrollbar textarea'); +for(i = 0; i < textareaElements.length; i++) { + textareaElements[i].classList.remove('scroll-hide'); + textareaElements[i].classList.add('pretty_scrollbar'); + textareaElements[i].style.resize = "none"; +} + +//------------------------------------------------ +// Remove some backgrounds +//------------------------------------------------ +const noBackgroundelements = document.querySelectorAll('.no-background'); +for(i = 0; i < noBackgroundelements.length; i++) { + noBackgroundelements[i].parentNode.style.border = 'none'; + noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = 'center'; +} + +//------------------------------------------------ +// Create the hover menu in the chat tab +// The show/hide events were adapted from: +// https://github.com/SillyTavern/SillyTavern/blob/6c8bd06308c69d51e2eb174541792a870a83d2d6/public/script.js +//------------------------------------------------ +var buttonsInChat = document.querySelectorAll("#chat-tab:not(.old-ui) #chat-buttons button"); +var button = document.getElementById('hover-element-button'); +var menu = document.getElementById('hover-menu'); + +function showMenu() { + menu.style.display = 'flex'; // Show the menu +} + +function hideMenu() { + menu.style.display = 'none'; // Hide the menu + document.querySelector('#chat-input textarea').focus(); +} + +if (buttonsInChat.length > 0) { + for (let i = buttonsInChat.length - 1; i >= 0; i--) { + const thisButton = buttonsInChat[i]; + menu.appendChild(thisButton); + + thisButton.addEventListener("click", () => { + hideMenu(); + }); + + const buttonText = thisButton.textContent; + const matches = buttonText.match(/(\(.*?\))/); + + if (matches && matches.length > 1) { + // Apply the transparent-substring class to the matched substring + const substring = matches[1]; + const newText = buttonText.replace(substring, ` ${substring.slice(1, -1)}`); + thisButton.innerHTML = newText; + } + } +} else { + buttonsInChat = document.querySelectorAll("#chat-tab.old-ui #chat-buttons button"); + for (let i = 0; i < buttonsInChat.length; i++) { + buttonsInChat[i].textContent = buttonsInChat[i].textContent.replace(/ \(.*?\)/, ''); + } + document.getElementById('gr-hover-container').style.display = 'none'; +} + +function isMouseOverButtonOrMenu() { + return menu.matches(':hover') || button.matches(':hover'); +} + +button.addEventListener('mouseenter', function () { + showMenu(); +}); + +button.addEventListener('click', function () { + showMenu(); +}); + +// Add event listener for mouseleave on the button +button.addEventListener('mouseleave', function () { + // Delay to prevent menu hiding when the mouse leaves the button into the menu + setTimeout(function () { + if (!isMouseOverButtonOrMenu()) { + hideMenu(); + } + }, 100); +}); + +// Add event listener for mouseleave on the menu +menu.addEventListener('mouseleave', function () { + // Delay to prevent menu hide when the mouse leaves the menu into the button + setTimeout(function () { + if (!isMouseOverButtonOrMenu()) { + hideMenu(); + } + }, 100); +}); + +// Add event listener for click anywhere in the document +document.addEventListener('click', function (event) { + // Check if the click is outside the button/menu and the menu is visible + if (!isMouseOverButtonOrMenu() && menu.style.display === 'flex') { + hideMenu(); + } +}); + +//------------------------------------------------ +// Relocate the "Show controls" checkbox +//------------------------------------------------ +var elementToMove = document.getElementById('show-controls'); +var parent = elementToMove.parentNode; +for (var i = 0; i < 2; i++) { + parent = parent.parentNode; +} + +parent.insertBefore(elementToMove, parent.firstChild); + +//------------------------------------------------ +// Make the chat input grow upwards instead of downwards +//------------------------------------------------ +document.getElementById('show-controls').parentNode.style.position = 'absolute'; +document.getElementById('show-controls').parentNode.style.bottom = '0px'; + +//------------------------------------------------ +// Focus on the chat input +//------------------------------------------------ +document.querySelector('#chat-input textarea').focus() diff --git a/js/save_files.js b/js/save_files.js new file mode 100644 index 0000000..d5b22c4 --- /dev/null +++ b/js/save_files.js @@ -0,0 +1,40 @@ +// Functions for downloading JSON files +function getCurrentTimestamp() { + const now = new Date(); + const timezoneOffset = now.getTimezoneOffset() * 60000; // Convert to milliseconds + const localTime = new Date(now.getTime() - timezoneOffset); + const formattedTimestamp = localTime.toISOString().replace(/[-:]/g, '').slice(0, 15); + return formattedTimestamp; +} + +function saveFile(contents, filename) { + const element = document.createElement('a'); + element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(contents)); + element.setAttribute('download', filename); + element.style.display = 'none'; + document.body.appendChild(element); + element.click(); + document.body.removeChild(element); +} + +function saveHistory(history, character, mode) { + let path = null; + + if (['chat', 'chat-instruct'].includes(mode) && character && character.trim() !== '') { + path = `history_${character}_${getCurrentTimestamp()}.json`; + } else { + try { + path = `history_${mode}_${getCurrentTimestamp()}.json`; + } catch (error) { + path = `history_${getCurrentTimestamp()}.json`; + } + } + saveFile(history, path); +} + +function saveSession(session) { + let path = null; + + path = `session_${getCurrentTimestamp()}.json`; + saveFile(session, path); +} diff --git a/js/show_controls.js b/js/show_controls.js new file mode 100644 index 0000000..b35463b --- /dev/null +++ b/js/show_controls.js @@ -0,0 +1,22 @@ +const belowChatInput = document.querySelectorAll("#chat-tab > div > :nth-child(n+2), #extensions"); +const chatParent = document.querySelector(".chat-parent"); + +function toggle_controls(value) { + if (value) { + belowChatInput.forEach(element => { + element.style.display = "inherit"; + }); + + chatParent.classList.remove("bigchat"); + document.getElementById('chat-input-row').classList.remove("bigchat"); + document.getElementById('chat-col').classList.remove("bigchat"); + } else { + belowChatInput.forEach(element => { + element.style.display = "none"; + }); + + chatParent.classList.add("bigchat"); + document.getElementById('chat-input-row').classList.add("bigchat") + document.getElementById('chat-col').classList.add("bigchat"); + } +} diff --git a/js/switch_tabs.js b/js/switch_tabs.js new file mode 100644 index 0000000..e49fef4 --- /dev/null +++ b/js/switch_tabs.js @@ -0,0 +1,59 @@ +let chat_tab = document.getElementById('chat-tab'); +let main_parent = chat_tab.parentNode; + +function scrollToTop() { + window.scrollTo({ + top: 0, + // behavior: 'smooth' + }); +} + +function findButtonsByText(buttonText) { + const buttons = document.getElementsByTagName('button'); + const matchingButtons = []; + buttonText = buttonText.trim(); + + for (let i = 0; i < buttons.length; i++) { + const button = buttons[i]; + const buttonInnerText = button.textContent.trim(); + + if (buttonInnerText === buttonText) { + matchingButtons.push(button); + } + } + + return matchingButtons; +} + +function switch_to_chat() { + let chat_tab_button = main_parent.childNodes[0].childNodes[1]; + chat_tab_button.click(); + scrollToTop(); +} + +function switch_to_default() { + let default_tab_button = main_parent.childNodes[0].childNodes[4]; + default_tab_button.click(); + scrollToTop(); +} + +function switch_to_notebook() { + let notebook_tab_button = main_parent.childNodes[0].childNodes[7]; + notebook_tab_button.click(); + findButtonsByText('Raw')[1].click() + scrollToTop(); +} + +function switch_to_generation_parameters() { + let parameters_tab_button = main_parent.childNodes[0].childNodes[10]; + parameters_tab_button.click(); + findButtonsByText('Generation')[0].click() + scrollToTop(); +} + +function switch_to_character() { + let parameters_tab_button = main_parent.childNodes[0].childNodes[10]; + parameters_tab_button.click(); + findButtonsByText('Character')[0].click() + scrollToTop(); +} diff --git a/models/config.yaml b/models/config.yaml index d81eac9..d98e95f 100644 --- a/models/config.yaml +++ b/models/config.yaml @@ -1,15 +1,27 @@ -.*(llama|alpac|vicuna|guanaco|koala|llava|wizardlm|metharme|pygmalion-7b|wizard-mega|openbuddy|vigogne|h2ogpt-research|manticore): +.*(llama|alpac|vicuna|guanaco|koala|llava|wizardlm|metharme|pygmalion-7b|pygmalion-2|mythalion|wizard-mega|openbuddy|vigogne|h2ogpt-research|manticore): model_type: 'llama' .*(opt-|opt_|opt1|opt3|optfor|galactica|galpaca|pygmalion-350m): model_type: 'opt' .*(gpt-j|gptj|gpt4all-j|malion-6b|pygway|pygmalion-6b|dolly-v1): model_type: 'gptj' .*(gpt-neox|koalpaca-polyglot|polyglot.*koalpaca|polyglot-ko|polyglot_ko|pythia|stablelm|incite|dolly-v2|polycoder|h2ogpt-oig|h2ogpt-oasst1|h2ogpt-gm): - model_type: 'gpt_neox' + model_type: 'gptneox' .*llama: model_type: 'llama' .*bloom: model_type: 'bloom' +.*gpt2: + model_type: 'gpt2' +.*falcon: + model_type: 'falcon' +.*mpt: + model_type: 'mpt' +.*(starcoder|starchat): + model_type: 'starcoder' +.*dolly-v2: + model_type: 'dollyv2' +.*replit: + model_type: 'replit' llama-65b-gptq-3bit: groupsize: 'None' .*(4bit|int4): @@ -37,204 +49,141 @@ llama-65b-gptq-3bit: .*(gr1024|1024g|groupsize1024): groupsize: 1024 .*(oasst|openassistant-|stablelm-7b-sft-v7-epoch-3): - mode: 'instruct' instruction_template: 'Open Assistant' skip_special_tokens: false (?!.*galactica)(?!.*reward).*openassistant: - mode: 'instruct' instruction_template: 'Open Assistant' skip_special_tokens: false (?!.*v0)(?!.*1.1)(?!.*1_1)(?!.*stable)(?!.*chinese).*vicuna: - mode: 'instruct' instruction_template: 'Vicuna-v0' .*vicuna.*v0: - mode: 'instruct' instruction_template: 'Vicuna-v0' .*vicuna.*(1.1|1_1|1.3|1_3): - mode: 'instruct' instruction_template: 'Vicuna-v1.1' -.*wizard.*vicuna: - mode: 'instruct' +.*vicuna.*(1.5|1_5): instruction_template: 'Vicuna-v1.1' + truncation_length: 4096 .*stable.*vicuna: - mode: 'instruct' instruction_template: 'StableVicuna' (?!.*chat).*chinese-vicuna: - mode: 'instruct' instruction_template: 'Alpaca' .*chinese-vicuna.*chat: - mode: 'instruct' instruction_template: 'Chinese-Vicuna-Chat' .*alpaca: - mode: 'instruct' instruction_template: 'Alpaca' .*alpaca-native-4bit: - mode: 'instruct' instruction_template: 'Alpaca' wbits: 4 groupsize: 128 .*galactica: skip_special_tokens: false .*dolly-v[0-9]-[0-9]*b: - mode: 'instruct' instruction_template: 'Alpaca' skip_special_tokens: false custom_stopping_strings: '"### End"' .*koala: - mode: 'instruct' instruction_template: 'Koala' .*chatglm: - mode: 'instruct' instruction_template: 'ChatGLM' -.*metharme: - mode: 'instruct' +.*(metharme|pygmalion|mythalion): instruction_template: 'Metharme' .*llava: - mode: 'instruct' model_type: 'llama' instruction_template: 'LLaVA' custom_stopping_strings: '"\n###"' .*raven: - mode: 'instruct' instruction_template: 'RWKV-Raven' +.*ctx8192: + truncation_length: 8192 .*moss-moon.*sft: - mode: 'instruct' instruction_template: 'MOSS' .*stablelm-tuned: - mode: 'instruct' instruction_template: 'StableLM' truncation_length: 4096 .*stablelm-base: truncation_length: 4096 -.*wizardlm: - mode: 'instruct' - model_type: 'llama' - instruction_template: 'WizardLM' .*galactica.*finetuned: - mode: 'instruct' instruction_template: 'Galactica Finetuned' .*galactica.*-v2: - mode: 'instruct' instruction_template: 'Galactica v2' (?!.*finetuned)(?!.*-v2).*galactica: - mode: 'instruct' instruction_template: 'Galactica' .*guanaco: - mode: 'instruct' instruction_template: 'Guanaco non-chat' .*baize: - mode: 'instruct' instruction_template: 'Baize' .*mpt-.*instruct: - mode: 'instruct' instruction_template: 'Alpaca' .*mpt-.*chat: - mode: 'instruct' instruction_template: 'MPT-Chat' (?!.*-flan-)(?!.*-t5-).*lamini-: - mode: 'instruct' instruction_template: 'Alpaca' .*incite.*chat: - mode: 'instruct' instruction_template: 'INCITE-Chat' .*incite.*instruct: - mode: 'instruct' instruction_template: 'INCITE-Instruct' .*wizard.*mega: - mode: 'instruct' instruction_template: 'Wizard-Mega' + custom_stopping_strings: '""' .*ziya-: - mode: 'instruct' instruction_template: 'Ziya' .*koalpaca: - mode: 'instruct' instruction_template: 'KoAlpaca' .*openbuddy: - mode: 'instruct' instruction_template: 'OpenBuddy' (?!.*chat).*vigogne: - mode: 'instruct' instruction_template: 'Vigogne-Instruct' .*vigogne.*chat: - mode: 'instruct' instruction_template: 'Vigogne-Chat' .*(llama-deus|supercot|llama-natural-instructions|open-llama-0.3t-7b-instruct-dolly-hhrlhf|open-llama-0.3t-7b-open-instruct): - mode: 'instruct' instruction_template: 'Alpaca' .*bactrian: - mode: 'instruct' instruction_template: 'Bactrian' .*(h2ogpt-oig-|h2ogpt-oasst1-|h2ogpt-research-oasst1-): - mode: 'instruct' instruction_template: 'H2O-human_bot' .*h2ogpt-gm-: - mode: 'instruct' instruction_template: 'H2O-prompt_answer' .*manticore: - mode: 'instruct' instruction_template: 'Manticore Chat' .*bluemoonrp-(30|13)b: - mode: 'instruct' instruction_template: 'Bluemoon' truncation_length: 4096 .*Nous-Hermes-13b: - mode: 'instruct' instruction_template: 'Alpaca' .*airoboros: - mode: 'instruct' - instruction_template: 'Vicuna-v1.1' -.*WizardLM-30B-V1.0: - mode: 'instruct' - instruction_template: 'Vicuna-v1.1' -TheBloke_WizardLM-30B-GPTQ: - mode: 'instruct' instruction_template: 'Vicuna-v1.1' +.*airoboros.*1.2: + instruction_template: 'Airoboros-v1.2' .*alpa(cino|sta): - mode: 'instruct' instruction_template: 'Alpaca' .*hippogriff: - mode: 'instruct' instruction_template: 'Hippogriff' -.*gpt4all-.*-snoozy: - mode: 'instruct' - instruction_template: 'WizardLM' .*lazarus: - mode: 'instruct' instruction_template: 'Alpaca' .*guanaco-.*(7|13|33|65)b: - mode: 'instruct' instruction_template: 'Guanaco' .*hypermantis: - mode: 'instruct' instruction_template: 'Alpaca' .*open-llama-.*-open-instruct: - mode: 'instruct' instruction_template: 'Alpaca' .*starcoder-gpteacher-code-instruct: - mode: 'instruct' instruction_template: 'Alpaca' .*tulu: - mode: 'instruct' instruction_template: 'Tulu' .*chronos: - mode: 'instruct' instruction_template: 'Alpaca' .*samantha: - mode: 'instruct' instruction_template: 'Samantha' .*wizardcoder: - mode: 'instruct' instruction_template: 'Alpaca' .*starchat-beta: - mode: 'instruct' instruction_template: 'Starchat-Beta' + custom_stopping_strings: '"<|end|>"' .*minotaur: - mode: 'instruct' instruction_template: 'Minotaur' .*minotaur-15b: truncation_length: 8192 .*orca_mini: - mode: 'instruct' instruction_template: 'Orca Mini' .*landmark: truncation_length: 8192 @@ -243,3 +192,38 @@ TheBloke_WizardLM-30B-GPTQ: .*xgen.*-inst: truncation_length: 8192 instruction_template: 'Vicuna-v0' +.*(platypus|gplatty|superplatty): + instruction_template: 'Alpaca' +.*longchat: + instruction_template: 'Vicuna-v1.1' +.*vicuna-33b: + instruction_template: 'Vicuna-v1.1' +.*redmond-hermes-coder: + instruction_template: 'Alpaca' + truncation_length: 8192 +.*wizardcoder-15b: + instruction_template: 'Alpaca' + truncation_length: 8192 +.*wizardlm: + instruction_template: 'Vicuna-v1.1' +.*godzilla: + instruction_template: 'Alpaca' +.*llama-(2|v2): + truncation_length: 4096 +.*llama(-?)(2|v2).*chat: + instruction_template: 'Llama-v2' +.*newhope: + instruction_template: 'NewHope' +.*stablebeluga2: + instruction_template: 'StableBeluga2' + truncation_length: 4096 +.*openchat: + instruction_template: 'OpenChat' +.*falcon.*-instruct: +.*(openorca-platypus2): + instruction_template: 'OpenOrca-Platypus2' + custom_stopping_strings: '"### Instruction:", "### Response:"' +.*codellama: + rope_freq_base: 1000000 +.*codellama.*instruct: + instruction_template: 'Llama-v2' \ No newline at end of file diff --git a/modules/AutoGPTQ_loader.py b/modules/AutoGPTQ_loader.py index 0d41ac0..987f5ba 100644 --- a/modules/AutoGPTQ_loader.py +++ b/modules/AutoGPTQ_loader.py @@ -50,6 +50,7 @@ def load_quantized(model_name): 'max_memory': get_max_memory_dict(), 'quantize_config': quantize_config, 'use_cuda_fp16': not shared.args.no_use_cuda_fp16, + 'disable_exllama': shared.args.disable_exllama, } logger.info(f"The AutoGPTQ params are: {params}") diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index ddc5f9a..bc528b1 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -1,6 +1,5 @@ import inspect import re -import sys from pathlib import Path import accelerate @@ -11,26 +10,9 @@ from transformers import AutoConfig, AutoModelForCausalLM import modules.shared as shared from modules.logging_colors import logger -sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa"))) - -try: - import llama_inference_offload -except ImportError: - logger.error('Failed to load GPTQ-for-LLaMa') - logger.error('See https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md') - sys.exit(-1) - -try: - from modelutils import find_layers -except ImportError: - from utils import find_layers - -try: - from quant import make_quant - is_triton = False -except ImportError: - import quant - is_triton = True +from gptq_for_llama import llama_inference_offload +from gptq_for_llama.modelutils import find_layers +from gptq_for_llama.quant import make_quant # This function is a replacement for the load_quant function in the @@ -59,24 +41,21 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc if name in layers: del layers[name] - if not is_triton: - gptq_args = inspect.getfullargspec(make_quant).args + gptq_args = inspect.getfullargspec(make_quant).args - make_quant_kwargs = { - 'module': model, - 'names': layers, - 'bits': wbits, - } - if 'groupsize' in gptq_args: - make_quant_kwargs['groupsize'] = groupsize - if 'faster' in gptq_args: - make_quant_kwargs['faster'] = faster_kernel - if 'kernel_switch_threshold' in gptq_args: - make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold + make_quant_kwargs = { + 'module': model, + 'names': layers, + 'bits': wbits, + } + if 'groupsize' in gptq_args: + make_quant_kwargs['groupsize'] = groupsize + if 'faster' in gptq_args: + make_quant_kwargs['faster'] = faster_kernel + if 'kernel_switch_threshold' in gptq_args: + make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold - make_quant(**make_quant_kwargs) - else: - quant.make_quant_linear(model, layers, wbits, groupsize) + make_quant(**make_quant_kwargs) del layers if checkpoint.endswith('.safetensors'): @@ -85,18 +64,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc else: model.load_state_dict(torch.load(checkpoint), strict=False) - if is_triton: - if shared.args.quant_attn: - quant.make_quant_attn(model) - - if eval and shared.args.fused_mlp: - quant.make_fused_mlp(model) - - if shared.args.warmup_autotune: - quant.autotune_warmup_linear(model, transpose=not eval) - if eval and shared.args.fused_mlp: - quant.autotune_warmup_fused(model) - model.seqlen = 2048 return model diff --git a/modules/LoRA.py b/modules/LoRA.py index 2eade07..1002055 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -17,6 +17,14 @@ def add_lora_to_model(lora_names): add_lora_transformers(lora_names) +def get_lora_path(lora_name): + p = Path(lora_name) + if p.exists(): + lora_name = p.parts[-1] + + return Path(f"{shared.args.lora_dir}/{lora_name}") + + def add_lora_exllama(lora_names): try: @@ -40,7 +48,7 @@ def add_lora_exllama(lora_names): if len(lora_names) > 1: logger.warning('ExLlama can only work with 1 LoRA at the moment. Only the first one in the list will be loaded.') - lora_path = Path(f"{shared.args.lora_dir}/{lora_names[0]}") + lora_path = get_lora_path(lora_names[0]) lora_config_path = lora_path / "adapter_config.json" lora_adapter_path = lora_path / "adapter_model.bin" @@ -66,7 +74,7 @@ def add_lora_autogptq(lora_names): logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.") return - if len(lora_names) == 0: + if len(lora_names) == 0: reload_model() shared.lora_names = [] @@ -81,7 +89,7 @@ def add_lora_autogptq(lora_names): inference_mode=True, ) - lora_path = Path(f"{shared.args.lora_dir}/{lora_names[0]}") + lora_path = get_lora_path(lora_names[0]) logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]]))) shared.model = get_gptq_peft_model(shared.model, peft_config, lora_path) shared.lora_names = [lora_names[0]] @@ -101,21 +109,21 @@ def add_lora_transformers(lora_names): if len(removed_set) == 0 and len(prior_set) > 0: logger.info(f"Adding the LoRA(s) named {added_set} to the model...") for lora in added_set: - shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora) + shared.model.load_adapter(get_lora_path(lora), lora) return # If any LoRA needs to be removed, start over if len(removed_set) > 0: # shared.model may no longer be PeftModel - if hasattr(shared.model, 'disable_adapter'): - shared.model.disable_adapter() + if hasattr(shared.model, 'disable_adapter'): + shared.model.disable_adapter() shared.model = shared.model.base_model.model if len(lora_names) > 0: params = {} if not shared.args.cpu: - if shared.args.load_in_4bit or shared.args.load_in_8bit: + if shared.args.load_in_4bit or shared.args.load_in_8bit: params['peft_type'] = shared.model.dtype else: params['dtype'] = shared.model.dtype @@ -123,16 +131,16 @@ def add_lora_transformers(lora_names): params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()} logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names))) - shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params) + shared.model = PeftModel.from_pretrained(shared.model, get_lora_path(lora_names[0]), adapter_name=lora_names[0], **params) for lora in lora_names[1:]: - shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora) + shared.model.load_adapter(get_lora_path(lora), lora) shared.lora_names = lora_names if not shared.args.load_in_8bit and not shared.args.cpu: shared.model.half() if not hasattr(shared.model, "hf_device_map"): - if torch.has_mps: + if torch.backends.mps.is_available(): device = torch.device('mps') shared.model = shared.model.to(device) else: diff --git a/modules/RoPE.py b/modules/RoPE.py new file mode 100644 index 0000000..c15616c --- /dev/null +++ b/modules/RoPE.py @@ -0,0 +1,18 @@ +def get_alpha_value(alpha, base): + ''' + Gets alpha_value from alpha_value and rope_freq_base + ''' + if base > 0: + return (base/10000.) ** (63/64.) + else: + return alpha + + +def get_rope_freq_base(alpha, base): + ''' + Gets rope_freq_base from alpha_value and rope_freq_base + ''' + if base > 0: + return base + else: + return 10000 * alpha ** (64/63.) diff --git a/modules/callbacks.py b/modules/callbacks.py index 1fa95e4..e29e397 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -24,6 +24,7 @@ class Stream(transformers.StoppingCriteria): def __call__(self, input_ids, scores) -> bool: if self.callback_func is not None: self.callback_func(input_ids[0]) + return False diff --git a/modules/chat.py b/modules/chat.py index c0635c2..334693a 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -1,8 +1,10 @@ import base64 import copy import functools +import html import json import re +from datetime import datetime from pathlib import Path import gradio as gr @@ -26,6 +28,22 @@ from modules.utils import ( ) +def str_presenter(dumper, data): + """ + Copied from https://github.com/yaml/pyyaml/issues/240 + Makes pyyaml output prettier multiline strings. + """ + + if data.count('\n') > 0: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) +yaml.representer.SafeRepresenter.add_representer(str, str_presenter) + + def get_turn_substrings(state, instruct=False): if instruct: if 'turn_template' not in state or state['turn_template'] == '': @@ -86,10 +104,19 @@ def generate_chat_prompt(user_input, state, **kwargs): else: wrapper = '<|prompt|>' + if is_instruct: + context = state['context_instruct'] + else: + context = replace_character_names( + f"{state['context'].strip()}\n", + state['name1'], + state['name2'] + ) + # Build the prompt + rows = [context] min_rows = 3 i = len(history) - 1 - rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"] while i >= 0 and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) < max_length: if _continue and i == len(history) - 1: if state['mode'] != 'chat-instruct': @@ -150,8 +177,8 @@ def get_stopping_strings(state): f"\n{state['name2']}:" ] - if state['stop_at_newline']: - stopping_strings.append("\n") + if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): + stopping_strings += state.pop('stopping_strings') return stopping_strings @@ -166,19 +193,18 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess yield output return - # Defining some variables just_started = True visible_text = None stopping_strings = get_stopping_strings(state) is_stream = state['stream'] - # Preparing the input + # Prepare the input if not any((regenerate, _continue)): - text, visible_text = apply_extensions('input_hijack', text, visible_text) - if visible_text is None: - visible_text = text + visible_text = html.escape(text) - text = apply_extensions('input', text, state) + # Apply extensions + text, visible_text = apply_extensions('chat_input', text, visible_text, state) + text = apply_extensions('input', text, state, is_chat=True) # *Is typing...* if loading_message: @@ -188,6 +214,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess if regenerate: output['visible'].pop() output['internal'].pop() + # *Is typing...* if loading_message: yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} @@ -196,86 +223,67 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess if loading_message: yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']} - # Generating the prompt + # Generate the prompt kwargs = { '_continue': _continue, 'history': output, } - prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs) if prompt is None: prompt = generate_chat_prompt(text, state, **kwargs) # Generate - cumulative_reply = '' - for i in range(state['chat_generation_attempts']): - reply = None - for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True)): - reply = cumulative_reply + reply + reply = None + for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)): - # Extract the reply - visible_reply = re.sub("(||{{user}})", state['name1'], reply) + # Extract the reply + visible_reply = re.sub("(||{{user}})", state['name1'], reply) + visible_reply = html.escape(visible_reply) - # We need this global variable to handle the Stop event, - # otherwise gradio gets confused - if shared.stop_everything: - output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state) + if shared.stop_everything: + output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True) + yield output + return + + if just_started: + just_started = False + if not _continue: + output['internal'].append(['', '']) + output['visible'].append(['', '']) + + if _continue: + output['internal'][-1] = [text, last_reply[0] + reply] + output['visible'][-1] = [visible_text, last_reply[1] + visible_reply] + if is_stream: + yield output + elif not (j == 0 and visible_reply.strip() == ''): + output['internal'][-1] = [text, reply.lstrip(' ')] + output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')] + if is_stream: yield output - return - if just_started: - just_started = False - if not _continue: - output['internal'].append(['', '']) - output['visible'].append(['', '']) - - if _continue: - output['internal'][-1] = [text, last_reply[0] + reply] - output['visible'][-1] = [visible_text, last_reply[1] + visible_reply] - if is_stream: - yield output - elif not (j == 0 and visible_reply.strip() == ''): - output['internal'][-1] = [text, reply.lstrip(' ')] - output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')] - if is_stream: - yield output - - if reply in [None, cumulative_reply]: - break - else: - cumulative_reply = reply - - output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state) + output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True) yield output -def impersonate_wrapper(text, start_with, state): +def impersonate_wrapper(text, state): + + static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style']) + if shared.model_name == 'None' or shared.model is None: logger.error("No model is loaded! Select one in the Model tab.") - yield '' + yield '', static_output return - # Defining some variables - cumulative_reply = '' prompt = generate_chat_prompt('', state, impersonate=True) stopping_strings = get_stopping_strings(state) - yield text + '...' - cumulative_reply = text - for i in range(state['chat_generation_attempts']): - reply = None - for reply in generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True): - reply = cumulative_reply + reply - yield reply.lstrip(' ') - if shared.stop_everything: - return - - if reply in [None, cumulative_reply]: - break - else: - cumulative_reply = reply - - yield cumulative_reply.lstrip(' ') + yield text + '...', static_output + reply = None + for reply in generate_reply(prompt + text, state, stopping_strings=stopping_strings, is_chat=True): + yield (text + reply).lstrip(' '), static_output + if shared.stop_everything: + return def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True): @@ -290,16 +298,33 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_ yield history -# Same as above but returns HTML for the UI -def generate_chat_reply_wrapper(text, start_with, state, regenerate=False, _continue=False): - if start_with != '' and not _continue: +def character_is_loaded(state, raise_exception=False): + if state['mode'] in ['chat', 'chat-instruct'] and state['name2'] == '': + logger.error('It looks like no character is loaded. Please load one under Parameters > Character.') + if raise_exception: + raise ValueError + + return False + else: + return True + + +def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): + ''' + Same as above but returns HTML for the UI + ''' + + if not character_is_loaded(state): + return + + if state['start_with'] != '' and not _continue: if regenerate: text, state['history'] = remove_last_message(state['history']) regenerate = False _continue = True send_dummy_message(text, state) - send_dummy_reply(start_with, state) + send_dummy_reply(state['start_with'], state) for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True)): yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style']), history @@ -312,29 +337,32 @@ def remove_last_message(history): else: last = ['', ''] - return last[0], history + return html.unescape(last[0]), history def send_last_reply_to_input(history): - if len(history['internal']) > 0: - return history['internal'][-1][1] + if len(history['visible']) > 0: + return html.unescape(history['visible'][-1][1]) else: return '' def replace_last_reply(text, state): history = state['history'] - if len(history['visible']) > 0: - history['visible'][-1][1] = text - history['internal'][-1][1] = apply_extensions('input', text, state) + + if len(text.strip()) == 0: + return history + elif len(history['visible']) > 0: + history['visible'][-1][1] = html.escape(text) + history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True) return history def send_dummy_message(text, state): history = state['history'] - history['visible'].append([text, '']) - history['internal'].append([apply_extensions('input', text, state), '']) + history['visible'].append([html.escape(text), '']) + history['internal'].append([apply_extensions('input', text, state, is_chat=True), '']) return history @@ -344,23 +372,8 @@ def send_dummy_reply(text, state): history['visible'].append(['', '']) history['internal'].append(['', '']) - history['visible'][-1][1] = text - history['internal'][-1][1] = apply_extensions('input', text, state) - return history - - -def clear_chat_log(state): - greeting = state['greeting'] - mode = state['mode'] - history = state['history'] - - history['visible'] = [] - history['internal'] = [] - if mode != 'instruct': - if greeting != '': - history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] - history['visible'] += [['', apply_extensions('output', greeting, state)]] - + history['visible'][-1][1] = html.escape(text) + history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True) return history @@ -368,53 +381,143 @@ def redraw_html(history, name1, name2, mode, style, reset_cache=False): return chat_html_wrapper(history, name1, name2, mode, style, reset_cache=reset_cache) -def save_history(history, path=None): - p = path or Path('logs/exported_history.json') - with open(p, 'w', encoding='utf-8') as f: - f.write(json.dumps(history, indent=4)) +def start_new_chat(state): + mode = state['mode'] + history = {'internal': [], 'visible': []} + + if mode != 'instruct': + greeting = replace_character_names(state['greeting'], state['name1'], state['name2']) + if greeting != '': + history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] + history['visible'] += [['', apply_extensions('output', greeting, state, is_chat=True)]] + + unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S') + save_history(history, unique_id, state['character_menu'], state['mode']) + + return history + + +def get_history_file_path(unique_id, character, mode): + if mode == 'instruct': + p = Path(f'logs/instruct/{unique_id}.json') + else: + p = Path(f'logs/chat/{character}/{unique_id}.json') return p -def load_history(file, history): +def save_history(history, unique_id, character, mode): + if shared.args.multi_user: + return + + p = get_history_file_path(unique_id, character, mode) + if not p.parent.is_dir(): + p.parent.mkdir(parents=True) + + with open(p, 'w', encoding='utf-8') as f: + f.write(json.dumps(history, indent=4)) + + +def rename_history(old_id, new_id, character, mode): + if shared.args.multi_user: + return + + old_p = get_history_file_path(old_id, character, mode) + new_p = get_history_file_path(new_id, character, mode) + if new_p.parent != old_p.parent: + logger.error(f"The following path is not allowed: {new_p}.") + elif new_p == old_p: + logger.info("The provided path is identical to the old one.") + else: + logger.info(f"Renaming {old_p} to {new_p}") + old_p.rename(new_p) + + +def find_all_histories(state): + if shared.args.multi_user: + return [''] + + if state['mode'] == 'instruct': + paths = Path('logs/instruct').glob('*.json') + else: + character = state['character_menu'] + + # Handle obsolete filenames and paths + old_p = Path(f'logs/{character}_persistent.json') + new_p = Path(f'logs/persistent_{character}.json') + if old_p.exists(): + logger.warning(f"Renaming {old_p} to {new_p}") + old_p.rename(new_p) + if new_p.exists(): + unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S') + p = get_history_file_path(unique_id, character, state['mode']) + logger.warning(f"Moving {new_p} to {p}") + p.parent.mkdir(exist_ok=True) + new_p.rename(p) + + paths = Path(f'logs/chat/{character}').glob('*.json') + + histories = sorted(paths, key=lambda x: x.stat().st_mtime, reverse=True) + histories = [path.stem for path in histories] + + return histories + + +def load_latest_history(state): + ''' + Loads the latest history for the given character in chat or chat-instruct + mode, or the latest instruct history for instruct mode. + ''' + + if shared.args.multi_user: + return start_new_chat(state) + + histories = find_all_histories(state) + + if len(histories) > 0: + unique_id = Path(histories[0]).stem + history = load_history(unique_id, state['character_menu'], state['mode']) + else: + history = start_new_chat(state) + + return history + + +def load_history(unique_id, character, mode): + p = get_history_file_path(unique_id, character, mode) + + f = json.loads(open(p, 'rb').read()) + if 'internal' in f and 'visible' in f: + history = f + else: + history = { + 'internal': f['data'], + 'visible': f['data_visible'] + } + + return history + + +def load_history_json(file, history): try: file = file.decode('utf-8') - j = json.loads(file) - if 'internal' in j and 'visible' in j: - return j + f = json.loads(file) + if 'internal' in f and 'visible' in f: + history = f else: - return history + history = { + 'internal': f['data'], + 'visible': f['data_visible'] + } + + return history except: return history -def save_persistent_history(history, character, mode): - if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user: - save_history(history, path=Path(f'logs/{character}_persistent.json')) - - -def load_persistent_history(state): - if state['mode'] == 'instruct': - return state['history'] - - character = state['character_menu'] - greeting = state['greeting'] - p = Path(f'logs/{character}_persistent.json') - if not shared.args.multi_user and character not in ['None', '', None] and p.exists(): - f = json.loads(open(p, 'rb').read()) - if 'internal' in f and 'visible' in f: - history = f - else: - history = {'internal': [], 'visible': []} - history['internal'] = f['data'] - history['visible'] = f['data_visible'] - else: - history = {'internal': [], 'visible': []} - if greeting != "": - history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] - history['visible'] += [['', apply_extensions('output', greeting, state)]] - - return history +def delete_history(unique_id, character, mode): + p = get_history_file_path(unique_id, character, mode) + delete_file(p) def replace_character_names(text, name1, name2): @@ -422,18 +525,6 @@ def replace_character_names(text, name1, name2): return text.replace('', name1).replace('', name2) -def build_pygmalion_style_context(data): - context = "" - if 'char_persona' in data and data['char_persona'] != '': - context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" - - if 'world_scenario' in data and data['world_scenario'] != '': - context += f"Scenario: {data['world_scenario']}\n" - - context = f"{context.strip()}\n\n" - return context - - def generate_pfp_cache(character): cache_folder = Path("cache") if not cache_folder.exists(): @@ -453,59 +544,55 @@ def load_character(character, name1, name2, instruct=False): greeting_field = 'greeting' picture = None - # Deleting the profile picture cache, if any - if Path("cache/pfp_character.png").exists(): + if instruct: + name1 = name2 = '' + folder = 'instruction-templates' + else: + folder = 'characters' + + filepath = None + for extension in ["yml", "yaml", "json"]: + filepath = Path(f'{folder}/{character}.{extension}') + if filepath.exists(): + break + + if filepath is None or not filepath.exists(): + logger.error(f"Could not find the character \"{character}\" inside {folder}/. No character has been loaded.") + raise ValueError + + file_contents = open(filepath, 'r', encoding='utf-8').read() + data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents) + + if Path("cache/pfp_character.png").exists() and not instruct: Path("cache/pfp_character.png").unlink() - if character not in ['None', '', None]: - folder = 'characters' if not instruct else 'characters/instruction-following' - picture = generate_pfp_cache(character) - for extension in ["yml", "yaml", "json"]: - filepath = Path(f'{folder}/{character}.{extension}') - if filepath.exists(): - break + picture = generate_pfp_cache(character) - file_contents = open(filepath, 'r', encoding='utf-8').read() - data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents) + # Finding the bot's name + for k in ['name', 'bot', '<|bot|>', 'char_name']: + if k in data and data[k] != '': + name2 = data[k] + break - # Finding the bot's name - for k in ['name', 'bot', '<|bot|>', 'char_name']: - if k in data and data[k] != '': - name2 = data[k] - break + # Find the user name (if any) + for k in ['your_name', 'user', '<|user|>']: + if k in data and data[k] != '': + name1 = data[k] + break - # Find the user name (if any) - for k in ['your_name', 'user', '<|user|>']: - if k in data and data[k] != '': - name1 = data[k] - break + if 'context' in data: + context = data['context'] + if not instruct: + context = context.strip() + '\n' + elif "char_persona" in data: + context = build_pygmalion_style_context(data) + greeting_field = 'char_greeting' - for field in ['context', 'greeting', 'example_dialogue', 'char_persona', 'char_greeting', 'world_scenario']: - if field in data: - data[field] = replace_character_names(data[field], name1, name2) + if greeting_field in data: + greeting = data[greeting_field] - if 'context' in data: - context = data['context'] - if not instruct: - context = context.strip() + '\n' - elif "char_persona" in data: - context = build_pygmalion_style_context(data) - greeting_field = 'char_greeting' - - if 'example_dialogue' in data: - context += f"{data['example_dialogue'].strip()}\n" - - if greeting_field in data: - greeting = data[greeting_field] - - if 'turn_template' in data: - turn_template = data['turn_template'] - - else: - context = shared.settings['context'] - name2 = shared.settings['name2'] - greeting = shared.settings['greeting'] - turn_template = shared.settings['turn_template'] + if 'turn_template' in data: + turn_template = data['turn_template'] return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n") @@ -515,40 +602,67 @@ def load_character_memoized(character, name1, name2, instruct=False): return load_character(character, name1, name2, instruct=instruct) -def upload_character(json_file, img, tavern=False): - json_file = json_file if type(json_file) == str else json_file.decode('utf-8') - data = json.loads(json_file) - outfile_name = data["char_name"] +def upload_character(file, img, tavern=False): + decoded_file = file if isinstance(file, str) else file.decode('utf-8') + try: + data = json.loads(decoded_file) + except: + data = yaml.safe_load(decoded_file) + + if 'char_name' in data: + name = data['char_name'] + greeting = data['char_greeting'] + context = build_pygmalion_style_context(data) + yaml_data = generate_character_yaml(name, greeting, context) + else: + name = data['name'] + yaml_data = generate_character_yaml(data['name'], data['greeting'], data['context']) + + outfile_name = name i = 1 - while Path(f'characters/{outfile_name}.json').exists(): - outfile_name = f'{data["char_name"]}_{i:03d}' + while Path(f'characters/{outfile_name}.yaml').exists(): + outfile_name = f'{name}_{i:03d}' i += 1 - if tavern: - outfile_name = f'TavernAI-{outfile_name}' - - with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f: - f.write(json_file) + with open(Path(f'characters/{outfile_name}.yaml'), 'w', encoding='utf-8') as f: + f.write(yaml_data) if img is not None: img.save(Path(f'characters/{outfile_name}.png')) - logger.info(f'New character saved to "characters/{outfile_name}.json".') + logger.info(f'New character saved to "characters/{outfile_name}.yaml".') return gr.update(value=outfile_name, choices=get_available_characters()) +def build_pygmalion_style_context(data): + context = "" + if 'char_persona' in data and data['char_persona'] != '': + context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" + + if 'world_scenario' in data and data['world_scenario'] != '': + context += f"Scenario: {data['world_scenario']}\n" + + if 'example_dialogue' in data and data['example_dialogue'] != '': + context += f"{data['example_dialogue'].strip()}\n" + + context = f"{context.strip()}\n" + return context + + def upload_tavern_character(img, _json): - _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} + _json = {'char_name': _json['name'], 'char_persona': _json['description'], 'char_greeting': _json['first_mes'], 'example_dialogue': _json['mes_example'], 'world_scenario': _json['scenario']} return upload_character(json.dumps(_json), img, tavern=True) def check_tavern_character(img): if "chara" not in img.info: return "Not a TavernAI card", None, None, gr.update(interactive=False) - decoded_string = base64.b64decode(img.info['chara']) + + decoded_string = base64.b64decode(img.info['chara']).replace(b'\\r\\n', b'\\n') _json = json.loads(decoded_string) if "data" in _json: _json = _json["data"] + return _json['name'], _json['description'], _json, gr.update(interactive=True) @@ -574,7 +688,7 @@ def generate_character_yaml(name, greeting, context): } data = {k: v for k, v in data.items() if v} # Strip falsy - return yaml.dump(data, sort_keys=False) + return yaml.dump(data, sort_keys=False, width=float("inf")) def generate_instruction_template_yaml(user, bot, context, turn_template): @@ -586,7 +700,7 @@ def generate_instruction_template_yaml(user, bot, context, turn_template): } data = {k: v for k, v in data.items() if v} # Strip falsy - return yaml.dump(data, sort_keys=False) + return yaml.dump(data, sort_keys=False, width=float("inf")) def save_character(name, greeting, context, picture, filename): diff --git a/modules/ctransformers_model.py b/modules/ctransformers_model.py new file mode 100644 index 0000000..70ce92f --- /dev/null +++ b/modules/ctransformers_model.py @@ -0,0 +1,79 @@ +from ctransformers import AutoConfig, AutoModelForCausalLM + +from modules import shared +from modules.callbacks import Iteratorize +from modules.logging_colors import logger + + +class CtransformersModel: + def __init__(self): + pass + + @classmethod + def from_pretrained(cls, path): + result = cls() + + config = AutoConfig.from_pretrained( + str(path), + threads=shared.args.threads if shared.args.threads != 0 else -1, + gpu_layers=shared.args.n_gpu_layers, + batch_size=shared.args.n_batch, + context_length=shared.args.n_ctx, + stream=True, + mmap=not shared.args.no_mmap, + mlock=shared.args.mlock + ) + + result.model = AutoModelForCausalLM.from_pretrained( + str(result.model_dir(path) if result.model_type_is_auto() else path), + model_type=(None if result.model_type_is_auto() else shared.args.model_type), + config=config + ) + + logger.info(f'Using ctransformers model_type: {result.model.model_type} for {result.model.model_path}') + return result, result + + def model_type_is_auto(self): + return shared.args.model_type is None or shared.args.model_type == "Auto" or shared.args.model_type == "None" + + def model_dir(self, path): + if path.is_file(): + return path.parent + + return path + + def encode(self, string, **kwargs): + return self.model.tokenize(string) + + def decode(self, ids): + return self.model.detokenize(ids) + + def generate(self, prompt, state, callback=None): + prompt = prompt if type(prompt) is str else prompt.decode() + # ctransformers uses -1 for random seed + generator = self.model( + prompt=prompt, + max_new_tokens=state['max_new_tokens'], + temperature=state['temperature'], + top_p=state['top_p'], + top_k=state['top_k'], + repetition_penalty=state['repetition_penalty'], + last_n_tokens=state['repetition_penalty_range'], + seed=int(state['seed']) + ) + + output = "" + for token in generator: + if callback: + callback(token) + + output += token + + return output + + def generate_with_streaming(self, *args, **kwargs): + with Iteratorize(self.generate, args, kwargs, callback=None) as generator: + reply = '' + for token in generator: + reply += token + yield reply diff --git a/modules/deepspeed_parameters.py b/modules/deepspeed_parameters.py index 9116f57..f170a38 100644 --- a/modules/deepspeed_parameters.py +++ b/modules/deepspeed_parameters.py @@ -1,6 +1,6 @@ def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir): ''' - DeepSpeed configration + DeepSpeed configuration https://huggingface.co/docs/transformers/main_classes/deepspeed ''' diff --git a/modules/evaluate.py b/modules/evaluate.py index d94863d..8044e20 100644 --- a/modules/evaluate.py +++ b/modules/evaluate.py @@ -8,10 +8,7 @@ from tqdm import tqdm from modules import shared from modules.models import load_model, unload_model -from modules.models_settings import ( - get_model_settings_from_yamls, - update_model_parameters -) +from modules.models_settings import get_model_metadata, update_model_parameters from modules.text_generation import encode @@ -69,8 +66,8 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): if model != 'current model': try: yield cumulative_log + f"Loading {model}...\n\n" - model_settings = get_model_settings_from_yamls(model) - shared.settings.update(model_settings) # hijacking the interface defaults + model_settings = get_model_metadata(model) + shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults update_model_parameters(model_settings) # hijacking the command-line arguments shared.model_name = model unload_model() diff --git a/modules/exllama.py b/modules/exllama.py index ecfb10a..cb92344 100644 --- a/modules/exllama.py +++ b/modules/exllama.py @@ -1,9 +1,12 @@ from pathlib import Path +import torch +import torch.nn.functional as F from torch import version as torch_version from modules import shared from modules.logging_colors import logger +from modules.models import clear_torch_cache from modules.text_generation import get_max_prompt_length try: @@ -11,7 +14,7 @@ try: from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig from exllama.tokenizer import ExLlamaTokenizer except: - logger.warning('Exllama module failed to load. Will attempt to load from repositories.') + logger.warning('exllama module failed to import. Will attempt to import from repositories/.') try: from modules.relative_imports import RelativeImport @@ -20,7 +23,10 @@ except: from model import ExLlama, ExLlamaCache, ExLlamaConfig from tokenizer import ExLlamaTokenizer except: - logger.error("Could not find repositories/exllama/. Make sure that exllama is cloned inside repositories/ and is up to date.") + logger.error( + "Could not find repositories/exllama. Please ensure that exllama" + " (https://github.com/turboderp/exllama) is cloned inside repositories/ and is up to date." + ) raise @@ -54,9 +60,11 @@ class ExllamaModel: config.set_auto_map(shared.args.gpu_split) config.gpu_peer_fix = True - if shared.args.alpha_value: + if shared.args.alpha_value > 1 and shared.args.rope_freq_base == 0: config.alpha_value = shared.args.alpha_value config.calculate_rotary_embedding_base() + elif shared.args.rope_freq_base > 0: + config.rotary_embedding_base = shared.args.rope_freq_base if torch_version.hip: config.rmsnorm_no_half2 = True @@ -77,7 +85,38 @@ class ExllamaModel: result.generator = generator return result, result + def encode(self, string, **kwargs): + return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True) + + def decode(self, ids, **kwargs): + if isinstance(ids, list): + ids = torch.tensor([ids]) + elif isinstance(ids, torch.Tensor) and ids.numel() == 1: + ids = ids.view(1, -1) + + return self.tokenizer.decode(ids)[0] + + def get_logits(self, token_ids, **kwargs): + self.cache.current_seq_len = 0 + self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True) + return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu() + def generate_with_streaming(self, prompt, state): + + # The cache batch size must be 2 for CFG and 1 otherwise + if state['guidance_scale'] == 1: + if self.cache.batch_size == 2: + del self.cache + clear_torch_cache() + self.cache = ExLlamaCache(self.model) + self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache) + else: + if self.cache.batch_size == 1: + del self.cache + clear_torch_cache() + self.cache = ExLlamaCache(self.model, batch_size=2) + self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache) + self.generator.settings.temperature = state['temperature'] self.generator.settings.top_p = state['top_p'] self.generator.settings.top_k = state['top_k'] @@ -89,27 +128,87 @@ class ExllamaModel: else: self.generator.disallow_tokens(None) - self.generator.end_beam_search() + if state['custom_token_bans']: + to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + if len(to_ban) > 0: + self.generator.disallow_tokens(to_ban) - # Tokenizing the input - ids = self.generator.tokenizer.encode(prompt) - ids = ids[:, -get_max_prompt_length(state):] + # Case 1: no CFG + if state['guidance_scale'] == 1: + self.generator.end_beam_search() - self.generator.gen_begin_reuse(ids) - initial_len = self.generator.sequence[0].shape[0] - has_leading_space = False - for i in range(state['max_new_tokens']): - token = self.generator.gen_single_token() - if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'): - has_leading_space = True + # Tokenizing the input + ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len) + if state['add_bos_token']: + ids = torch.cat( + [torch.tensor([[self.tokenizer.bos_token_id]]).to(ids.device), + ids], dim=1 + ).to(torch.int64) + ids = ids[:, -get_max_prompt_length(state):] + if state['auto_max_new_tokens']: + max_new_tokens = state['truncation_length'] - ids.shape[-1] + else: + max_new_tokens = state['max_new_tokens'] - decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:]) - if has_leading_space: - decoded_text = ' ' + decoded_text + self.generator.gen_begin_reuse(ids) + initial_len = self.generator.sequence[0].shape[0] + has_leading_space = False - yield decoded_text - if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything: - break + for i in range(max_new_tokens): + token = self.generator.gen_single_token() + if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'): + has_leading_space = True + + decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:]) + if has_leading_space: + decoded_text = ' ' + decoded_text + + yield decoded_text + if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything: + break + + # Case 2: CFG + # Copied from https://github.com/turboderp/exllama/blob/master/example_cfg.py + else: + alpha = state['guidance_scale'] + prompts = [prompt, state['negative_prompt'] or ''] + + ids, mask = self.tokenizer.encode( + prompts, + return_mask=True, + max_seq_len=self.model.config.max_seq_len, + add_bos=state['add_bos_token'] + ) + if state['auto_max_new_tokens']: + max_new_tokens = state['truncation_length'] - ids[0].shape[-1] + else: + max_new_tokens = state['max_new_tokens'] + + self.generator.gen_begin(ids, mask=mask) + initial_len = self.generator.sequence[0].shape[0] + has_leading_space = False + + for i in range(max_new_tokens): + logits = self.model.forward(self.generator.sequence[:, -1:], self.cache, input_mask=mask) + self.generator.apply_rep_penalty(logits) + + logits = F.log_softmax(logits, dim=-1) + logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1] + + token, _ = self.generator.sample_current(logits_mixed) + if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'): + has_leading_space = True + + decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:]) + if has_leading_space: + decoded_text = ' ' + decoded_text + + yield decoded_text + if token.item() == self.tokenizer.eos_token_id or shared.stop_everything: + break + + batch_token = token.repeat(2, 1) + self.generator.gen_accept_token(batch_token) def generate(self, prompt, state): output = '' @@ -117,9 +216,3 @@ class ExllamaModel: pass return output - - def encode(self, string, **kwargs): - return self.tokenizer.encode(string) - - def decode(self, string, **kwargs): - return self.tokenizer.decode(string)[0] diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py index a25c3f4..3245ac8 100644 --- a/modules/exllama_hf.py +++ b/modules/exllama_hf.py @@ -32,6 +32,13 @@ class ExllamaHF(PreTrainedModel): self.generation_config = GenerationConfig() self.lora = None + self.ex_cache = ExLlamaCache(self.ex_model) + self.past_seq = None + + if shared.args.cfg_cache: + self.ex_cache_negative = ExLlamaCache(self.ex_model) + self.past_seq_negative = None + def _validate_model_class(self): pass @@ -46,17 +53,62 @@ class ExllamaHF(PreTrainedModel): return torch.device(0) def __call__(self, *args, **kwargs): - # TODO: Some decoding methods (such as Contrastive Search) may not work at this time - assert len(args) == 0, 'no *args should be passed to forward' use_cache = kwargs.get('use_cache', True) labels = kwargs.get('labels', None) - seq = kwargs['input_ids'][0].tolist() - cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None - if cache is None: - cache = ExLlamaCache(self.ex_model) - self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora) + past_key_values = kwargs.get('past_key_values', None) - logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device) + if len(args) > 0: + if not shared.args.cfg_cache: + logger.error("Please enable the cfg-cache option to use CFG with ExLlama_HF.") + return + + input_ids = args[0] + is_negative = True + past_seq = self.past_seq_negative + ex_cache = self.ex_cache_negative + else: + input_ids = kwargs['input_ids'] + is_negative = False + past_seq = self.past_seq + ex_cache = self.ex_cache + + seq = input_ids[0].tolist() + if is_negative and past_key_values is not None: + seq = past_key_values + seq + + seq_tensor = torch.tensor(seq) + reset = True + + # Make the forward call + if labels is None: + if past_seq is not None: + min_length = min(past_seq.shape[0], seq_tensor.shape[0]) + indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) + if len(indices) > 0: + longest_prefix = indices[0].item() + else: + longest_prefix = min_length + + if longest_prefix > 0: + reset = False + ex_cache.current_seq_len = longest_prefix + if len(seq_tensor) - longest_prefix > 1: + self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora) + + if reset: + ex_cache.current_seq_len = 0 + if len(seq_tensor) > 1: + self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora) + + logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, lora=self.lora).to(input_ids.device) + else: + ex_cache.current_seq_len = 0 + logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, lora=self.lora) + + if is_negative: + self.past_seq_negative = seq_tensor + else: + self.past_seq = seq_tensor loss = None if labels is not None: @@ -71,7 +123,7 @@ class ExllamaHF(PreTrainedModel): shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) - return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None) + return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): @@ -98,9 +150,11 @@ class ExllamaHF(PreTrainedModel): config.set_auto_map(shared.args.gpu_split) config.gpu_peer_fix = True - if shared.args.alpha_value: + if shared.args.alpha_value > 1 and shared.args.rope_freq_base == 0: config.alpha_value = shared.args.alpha_value config.calculate_rotary_embedding_base() + elif shared.args.rope_freq_base > 0: + config.rotary_embedding_base = shared.args.rope_freq_base if torch.version.hip: config.rmsnorm_no_half2 = True diff --git a/modules/exllamav2.py b/modules/exllamav2.py new file mode 100644 index 0000000..be5f47e --- /dev/null +++ b/modules/exllamav2.py @@ -0,0 +1,132 @@ +import random +from pathlib import Path + +import torch +from exllamav2 import ( + ExLlamaV2, + ExLlamaV2Cache, + ExLlamaV2Config, + ExLlamaV2Tokenizer +) +from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler + +from modules import shared +from modules.logging_colors import logger +from modules.text_generation import get_max_prompt_length + +try: + import flash_attn +except ModuleNotFoundError: + logger.warning( + 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage ' + 'to be a lot higher than it could be.\n' + 'Try installing flash-attention following the instructions here: ' + 'https://github.com/Dao-AILab/flash-attention#installation-and-features' + ) + pass + + +class Exllamav2Model: + def __init__(self): + pass + + @classmethod + def from_pretrained(self, path_to_model): + + path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model) + + config = ExLlamaV2Config() + config.model_dir = str(path_to_model) + config.prepare() + + config.max_seq_len = shared.args.max_seq_len + config.scale_pos_emb = shared.args.compress_pos_emb + config.scale_alpha_value = shared.args.alpha_value + + model = ExLlamaV2(config) + + split = None + if shared.args.gpu_split: + split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + + model.load(split) + + tokenizer = ExLlamaV2Tokenizer(config) + cache = ExLlamaV2Cache(model) + generator = ExLlamaV2BaseGenerator(model, cache, tokenizer) + + result = self() + result.model = model + result.cache = cache + result.tokenizer = tokenizer + result.generator = generator + return result, result + + def encode(self, string, **kwargs): + return self.tokenizer.encode(string, add_bos=True) + + def decode(self, ids, **kwargs): + if isinstance(ids, list): + ids = torch.tensor([ids]) + elif isinstance(ids, torch.Tensor) and ids.numel() == 1: + ids = ids.view(1, -1) + + return self.tokenizer.decode(ids)[0] + + def get_logits(self, token_ids, **kwargs): + self.cache.current_seq_len = 0 + self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True) + return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu() + + def generate_with_streaming(self, prompt, state): + settings = ExLlamaV2Sampler.Settings() + settings.temperature = state['temperature'] + settings.top_k = state['top_k'] + settings.top_p = state['top_p'] + settings.token_repetition_penalty = state['repetition_penalty'] + settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range'] + if state['ban_eos_token']: + settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id]) + + if state['custom_token_bans']: + to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + if len(to_ban) > 0: + settings.disallow_tokens(self.tokenizer, to_ban) + + ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token']) + ids = ids[:, -get_max_prompt_length(state):] + initial_len = ids.shape[-1] + + if state['auto_max_new_tokens']: + max_new_tokens = state['truncation_length'] - ids.shape[-1] + else: + max_new_tokens = state['max_new_tokens'] + + # _gen_begin_base + self.cache.current_seq_len = 0 + self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True) + + has_leading_space = False + for i in range(max_new_tokens): + logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu() + token, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random()) + ids = torch.cat([ids, token], dim=1) + + if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'): + has_leading_space = True + + decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0] + if has_leading_space: + decoded_text = ' ' + decoded_text + + yield decoded_text + + if token.item() == self.tokenizer.eos_token_id or shared.stop_everything: + break + + def generate(self, prompt, state): + output = '' + for output in self.generate_with_streaming(prompt, state): + pass + + return output diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py new file mode 100644 index 0000000..6542ede --- /dev/null +++ b/modules/exllamav2_hf.py @@ -0,0 +1,148 @@ +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config +from torch.nn import CrossEntropyLoss +from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from modules import shared +from modules.logging_colors import logger + +try: + import flash_attn +except ModuleNotFoundError: + logger.warning( + 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage ' + 'to be a lot higher than it could be.\n' + 'Try installing flash-attention following the instructions here: ' + 'https://github.com/Dao-AILab/flash-attention#installation-and-features' + ) + pass + + +class Exllamav2HF(PreTrainedModel): + def __init__(self, config: ExLlamaV2Config): + super().__init__(PretrainedConfig()) + self.ex_config = config + self.ex_model = ExLlamaV2(config) + split = None + if shared.args.gpu_split: + split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + + self.ex_model.load(split) + + self.generation_config = GenerationConfig() + + self.ex_cache = ExLlamaV2Cache(self.ex_model) + self.past_seq = None + + if shared.args.cfg_cache: + self.ex_cache_negative = ExLlamaV2Cache(self.ex_model) + self.past_seq_negative = None + + def _validate_model_class(self): + pass + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + pass + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + return {'input_ids': input_ids, **kwargs} + + @property + def device(self) -> torch.device: + return torch.device(0) + + def __call__(self, *args, **kwargs): + use_cache = kwargs.get('use_cache', True) + labels = kwargs.get('labels', None) + past_key_values = kwargs.get('past_key_values', None) + + if len(args) > 0: + if not shared.args.cfg_cache: + logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.") + return + + input_ids = args[0] + is_negative = True + past_seq = self.past_seq_negative + ex_cache = self.ex_cache_negative + else: + input_ids = kwargs['input_ids'] + is_negative = False + past_seq = self.past_seq + ex_cache = self.ex_cache + + seq = input_ids[0].tolist() + if is_negative and past_key_values is not None: + seq = past_key_values + seq + + seq_tensor = torch.tensor(seq) + reset = True + + # Make the forward call + if labels is None: + if past_seq is not None: + min_length = min(past_seq.shape[0], seq_tensor.shape[0]) + indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) + if len(indices) > 0: + longest_prefix = indices[0].item() + else: + longest_prefix = min_length + + if longest_prefix > 0: + reset = False + ex_cache.current_seq_len = longest_prefix + if len(seq_tensor) - longest_prefix > 1: + self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True) + + if reset: + ex_cache.current_seq_len = 0 + if len(seq_tensor) > 1: + self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True) + + logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device) + else: + ex_cache.current_seq_len = 0 + logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False) + + if is_negative: + self.past_seq_negative = seq_tensor + else: + self.past_seq = seq_tensor + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, logits.shape[-1]) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported" + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + + pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) + + config = ExLlamaV2Config() + config.model_dir = str(pretrained_model_name_or_path) + config.prepare() + + config.max_seq_len = shared.args.max_seq_len + config.scale_pos_emb = shared.args.compress_pos_emb + config.scale_alpha_value = shared.args.alpha_value + + return Exllamav2HF(config) diff --git a/modules/extensions.py b/modules/extensions.py index 8705101..6c07250 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,13 +1,12 @@ import traceback from functools import partial +from inspect import signature import gradio as gr import extensions import modules.shared as shared from modules.logging_colors import logger -from inspect import signature - state = {} available_extensions = [] @@ -28,6 +27,7 @@ def apply_settings(extension, name): def load_extensions(): global state, setup_called + state = {} for i, name in enumerate(shared.args.extensions): if name in available_extensions: if name != 'api': @@ -54,27 +54,41 @@ def iterator(): # Extension functions that map string -> string -def _apply_string_extensions(function_name, text, state): +def _apply_string_extensions(function_name, text, state, is_chat=False): for extension, _ in iterator(): if hasattr(extension, function_name): func = getattr(extension, function_name) - if len(signature(func).parameters) == 2: - text = func(text, state) + + # Handle old extensions without the 'state' arg or + # the 'is_chat' kwarg + count = 0 + has_chat = False + for k in signature(func).parameters: + if k == 'is_chat': + has_chat = True + else: + count += 1 + + if count == 2: + args = [text, state] else: - text = func(text) + args = [text] + + if has_chat: + kwargs = {'is_chat': is_chat} + else: + kwargs = {} + + text = func(*args, **kwargs) return text -# Input hijack of extensions -def _apply_input_hijack(text, visible_text): +# Extension functions that map string -> string +def _apply_chat_input_extensions(text, visible_text, state): for extension, _ in iterator(): - if hasattr(extension, 'input_hijack') and extension.input_hijack['state']: - extension.input_hijack['state'] = False - if callable(extension.input_hijack['value']): - text, visible_text = extension.input_hijack['value'](text, visible_text) - else: - text, visible_text = extension.input_hijack['value'] + if hasattr(extension, 'chat_input_modifier'): + text, visible_text = extension.chat_input_modifier(text, visible_text, state) return text, visible_text @@ -106,15 +120,27 @@ def _apply_history_modifier_extensions(history): return history -# Extension functions that override the default tokenizer output - currently only the first one will work +# Extension functions that override the default tokenizer output - The order of execution is not defined def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): for extension, _ in iterator(): if hasattr(extension, function_name): - return getattr(extension, function_name)(state, prompt, input_ids, input_embeds) + prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds) return prompt, input_ids, input_embeds +# Allow extensions to add their own logits processors to the stack being run. +# Each extension would call `processor_list.append({their LogitsProcessor}())`. +def _apply_logits_processor_extensions(function_name, processor_list, input_ids): + for extension, _ in iterator(): + if hasattr(extension, function_name): + result = getattr(extension, function_name)(processor_list, input_ids) + if type(result) is list: + processor_list = result + + return processor_list + + # Get prompt length in tokens after applying extension functions which override the default tokenizer output # currently only the first one will work def _apply_custom_tokenized_length(prompt): @@ -162,9 +188,7 @@ def create_extensions_block(): if len(to_display) > 0: with gr.Column(elem_id="extensions"): for row in to_display: - extension, name = row - display_name = getattr(extension, 'params', {}).get('display_name', name) - gr.Markdown(f"\n### {display_name}") + extension, _ = row extension.ui() @@ -179,11 +203,12 @@ def create_extensions_tabs(): EXTENSION_MAP = { "input": partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), + "chat_input": _apply_chat_input_extensions, "state": _apply_state_modifier_extensions, "history": _apply_history_modifier_extensions, "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"), - "input_hijack": _apply_input_hijack, + 'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'), "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt, "custom_generate_reply": _apply_custom_generate_reply, "tokenized_length": _apply_custom_tokenized_length, diff --git a/modules/html_generator.py b/modules/html_generator.py index 4910ffe..2da2793 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -1,3 +1,4 @@ +import html import os import re import time @@ -6,6 +7,7 @@ from pathlib import Path import markdown from PIL import Image, ImageOps +from modules.logging_colors import logger from modules.utils import get_available_chat_styles # This is to store the paths to the thumbnails of the profile pictures @@ -23,6 +25,16 @@ chat_styles = {} for k in get_available_chat_styles(): chat_styles[k] = open(Path(f'css/chat_style-{k}.css'), 'r').read() +# Handle styles that derive from other styles +for k in chat_styles: + lines = chat_styles[k].split('\n') + input_string = lines[0] + match = re.search(r'chat_style-([a-z\-]*)\.css', input_string) + + if match: + style = match.group(1) + chat_styles[k] = chat_styles.get(style, '') + '\n\n' + '\n'.join(lines[1:]) + def fix_newlines(string): string = string.replace('\n', '\n\n') @@ -38,6 +50,7 @@ def replace_blockquote(m): def convert_to_markdown(string): # Blockquote + string = re.sub(r'(^|[\n])>', r'\1>', string) pattern = re.compile(r'\\begin{blockquote}(.*?)\\end{blockquote}', re.DOTALL) string = pattern.sub(replace_blockquote, string) @@ -58,11 +71,32 @@ def convert_to_markdown(string): else: result += '\n\n' + result = result.strip() if is_code: - result = result + '```' # Unfinished code block + result += '\n```' # Unfinished code block - string = result.strip() - return markdown.markdown(string, extensions=['fenced_code', 'tables']) + # Unfinished list, like "\n1.". A |delete| string is added and then + # removed to force a
    or
      to be generated instead of a

      . + if re.search(r'(\n\d+\.?|\n\*\s*)$', result): + delete_str = '|delete|' + + if re.search(r'(\d+\.?)$', result) and not result.endswith('.'): + result += '.' + + result = re.sub(r'(\n\d+\.?|\n\*\s*)$', r'\g<1> ' + delete_str, result) + + html_output = markdown.markdown(result, extensions=['fenced_code', 'tables']) + pos = html_output.rfind(delete_str) + if pos > -1: + html_output = html_output[:pos] + html_output[pos + len(delete_str):] + else: + html_output = markdown.markdown(result, extensions=['fenced_code', 'tables']) + + # Unescape code blocks + pattern = re.compile(r']*>(.*?)', re.DOTALL) + html_output = pattern.sub(lambda x: html.unescape(x.group()), html_output) + + return html_output def generate_basic_html(string): @@ -81,7 +115,7 @@ def process_post(post, c): src = re.sub('>', '>', src) src = re.sub('(>>[0-9]*)', '\\1', src) src = re.sub('\n', '
      \n', src) - src = f'

      {src}\n' + src = f'
      {src}\n' src = f'Anonymous No.{number}\n{src}' return src @@ -102,6 +136,7 @@ def generate_4chan_html(f): post = line else: post += line + if post != '': src = process_post(post, c) posts.append(src) @@ -116,13 +151,14 @@ def generate_4chan_html(f): output += f'
      ' for post in posts: output += post + output += '
      ' output = output.split('\n') for i in range(len(output)): output[i] = re.sub(r'^(>(.*?)(
      |))', r'\1', output[i]) - output[i] = re.sub(r'^
      (>(.*?)(
      |))', r'
      \1', output[i]) - output = '\n'.join(output) + output[i] = re.sub(r'^
      (>(.*?)(
      |))', r'
      \1', output[i]) + output = '\n'.join(output) return output @@ -142,7 +178,13 @@ def get_image_cache(path): mtime = os.stat(path).st_mtime if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache): img = make_thumbnail(Image.open(path)) - output_file = Path(f'cache/{path.name}_cache.png') + + old_p = Path(f'cache/{path.name}_cache.png') + p = Path(f'cache/cache_{path.name}.png') + if old_p.exists(): + old_p.rename(p) + + output_file = p img.convert('RGB').save(output_file, format='PNG') image_cache[path] = [mtime, output_file.as_posix()] @@ -150,10 +192,21 @@ def get_image_cache(path): def generate_instruct_html(history): - output = f'
      ' - for i, _row in enumerate(history[::-1]): + output = f'
      ' + for i, _row in enumerate(history): row = [convert_to_markdown(entry) for entry in _row] + if row[0]: # don't display empty user messages + output += f""" +
      +
      +
      + {row[0]} +
      +
      +
      + """ + output += f"""
      @@ -164,34 +217,38 @@ def generate_instruct_html(history):
      """ - if len(row[0]) == 0: # don't display empty user messages - continue - - output += f""" -
      -
      -
      - {row[0]} -
      -
      -
      - """ - - output += "
      " + output += "
      " return output def generate_cai_chat_html(history, name1, name2, style, reset_cache=False): - output = f'
      ' + output = f'
      ' # We use ?name2 and ?time.time() to force the browser to reset caches img_bot = f'' if Path("cache/pfp_character.png").exists() else '' img_me = f'' if Path("cache/pfp_me.png").exists() else '' - for i, _row in enumerate(history[::-1]): + for i, _row in enumerate(history): row = [convert_to_markdown(entry) for entry in _row] + if row[0]: # don't display empty user messages + output += f""" +
      +
      + {img_me} +
      +
      +
      + {name1} +
      +
      + {row[0]} +
      +
      +
      + """ + output += f"""
      @@ -208,49 +265,18 @@ def generate_cai_chat_html(history, name1, name2, style, reset_cache=False):
      """ - if len(row[0]) == 0: # don't display empty user messages - continue - - output += f""" -
      -
      - {img_me} -
      -
      -
      - {name1} -
      -
      - {row[0]} -
      -
      -
      - """ - - output += "
      " + output += "
      " return output def generate_chat_html(history, name1, name2, reset_cache=False): - output = f'
      ' + output = f'
      ' - for i, _row in enumerate(history[::-1]): + for i, _row in enumerate(history): row = [convert_to_markdown(entry) for entry in _row] - output += f""" -
      -
      -
      - {row[1]} -
      -
      -
      - """ - - if len(row[0]) == 0: # don't display empty user messages - continue - - output += f""" + if row[0]: # don't display empty user messages + output += f"""
      @@ -260,7 +286,17 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
      """ - output += "
      " + output += f""" +
      +
      +
      + {row[1]} +
      +
      +
      + """ + + output += "
      " return output diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py new file mode 100644 index 0000000..3cb5df1 --- /dev/null +++ b/modules/llamacpp_hf.py @@ -0,0 +1,211 @@ +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from torch.nn import CrossEntropyLoss +from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from modules import RoPE, shared +from modules.logging_colors import logger + +import llama_cpp + +if torch.cuda.is_available() and not torch.version.hip: + try: + import llama_cpp_cuda + except: + llama_cpp_cuda = None +else: + llama_cpp_cuda = None + + +def llama_cpp_lib(): + if shared.args.cpu or llama_cpp_cuda is None: + return llama_cpp + else: + return llama_cpp_cuda + + +class LlamacppHF(PreTrainedModel): + def __init__(self, model, path): + super().__init__(PretrainedConfig()) + self.model = model + self.generation_config = GenerationConfig() + + self.past_seq = None + self.llamacpp_cache = { + 'n_tokens': self.model.n_tokens, + 'input_ids': self.model.input_ids, + 'scores': self.model.scores, + 'ctx': self.model.ctx + } + + if shared.args.cfg_cache: + self.past_seq_negative = None + self.llamacpp_cache_negative = { + 'n_tokens': self.model.n_tokens, + 'input_ids': self.model.input_ids.copy(), + 'scores': self.model.scores.copy(), + 'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.params) + } + + def _validate_model_class(self): + pass + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + pass + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + return {'input_ids': input_ids, **kwargs} + + def save_cache(self): + self.llamacpp_cache.update({ + 'n_tokens': self.model.n_tokens, + 'input_ids': self.model.input_ids, + 'scores': self.model.scores, + 'ctx': self.model.ctx + }) + + def save_negative_cache(self): + self.llamacpp_cache_negative.update({ + 'n_tokens': self.model.n_tokens, + 'input_ids': self.model.input_ids, + 'scores': self.model.scores, + 'ctx': self.model.ctx + }) + + def load_cache(self): + self.model.n_tokens = self.llamacpp_cache['n_tokens'] + self.model.input_ids = self.llamacpp_cache['input_ids'] + self.model.scores = self.llamacpp_cache['scores'] + self.model.ctx = self.llamacpp_cache['ctx'] + + def load_negative_cache(self): + self.model.n_tokens = self.llamacpp_cache_negative['n_tokens'] + self.model.input_ids = self.llamacpp_cache_negative['input_ids'] + self.model.scores = self.llamacpp_cache_negative['scores'] + self.model.ctx = self.llamacpp_cache_negative['ctx'] + + @property + def device(self) -> torch.device: + return torch.device(0) + + def __call__(self, *args, **kwargs): + use_cache = kwargs.get('use_cache', True) + labels = kwargs.get('labels', None) + past_key_values = kwargs.get('past_key_values', None) + + if len(args) > 0: + if not shared.args.cfg_cache: + logger.error("Please enable the cfg-cache option to use CFG with llamacpp_HF.") + return + + input_ids = args[0] + is_negative = True + past_seq = self.past_seq_negative + self.load_negative_cache() + else: + input_ids = kwargs['input_ids'] + is_negative = False + past_seq = self.past_seq + self.load_cache() + + seq = input_ids[0].tolist() + if is_negative and past_key_values is not None: + seq = past_key_values + seq + + seq_tensor = torch.tensor(seq) + reset = True + + # Make the forward call. The prefix-match code has been adapted from + # https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee + if labels is None: + if past_seq is not None: + min_length = min(past_seq.shape[0], seq_tensor.shape[0]) + indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) + if len(indices) > 0: + longest_prefix = indices[0].item() + else: + longest_prefix = min_length + + if longest_prefix > 0: + reset = False + self.model.n_tokens = longest_prefix + if len(seq_tensor) - longest_prefix > 0: + self.model.eval(seq[longest_prefix:]) + + if reset: + self.model.reset() + self.model.eval(seq) + + logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device) + else: + self.model.reset() + self.model.eval(seq) + logits = torch.tensor(self.model.eval_logits) + logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device) + + if is_negative: + self.save_negative_cache() + self.past_seq_negative = seq_tensor + else: + self.save_cache() + self.past_seq = seq_tensor + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, logits.shape[-1]) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported" + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + + path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) + if path.is_file(): + model_file = path + else: + model_file = list(path.glob('*.gguf'))[0] + + logger.info(f"llama.cpp weights detected: {model_file}\n") + + if shared.args.tensor_split is None or shared.args.tensor_split.strip() == '': + tensor_split_list = None + else: + tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")] + + params = { + 'model_path': str(model_file), + 'n_ctx': shared.args.n_ctx, + 'seed': int(shared.args.llama_cpp_seed), + 'n_threads': shared.args.threads or None, + 'n_batch': shared.args.n_batch, + 'use_mmap': not shared.args.no_mmap, + 'use_mlock': shared.args.mlock, + 'mul_mat_q': shared.args.mul_mat_q, + 'low_vram': shared.args.low_vram, + 'n_gpu_layers': shared.args.n_gpu_layers, + 'rope_freq_base': RoPE.get_rope_freq_base(shared.args.alpha_value, shared.args.rope_freq_base), + 'tensor_split': tensor_split_list, + 'rope_freq_scale': 1.0 / shared.args.compress_pos_emb, + 'logits_all': True, + } + + Llama = llama_cpp_lib().Llama + model = Llama(**params) + + return LlamacppHF(model, model_file) diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index 4899ad9..951267e 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -1,19 +1,30 @@ -''' -Based on -https://github.com/abetlen/llama-cpp-python - -Documentation: -https://abetlen.github.io/llama-cpp-python/ -''' - import re from functools import partial -from llama_cpp import Llama, LlamaCache, LogitsProcessorList +import numpy as np +import torch -from modules import shared +from modules import RoPE, shared from modules.callbacks import Iteratorize from modules.logging_colors import logger +from modules.text_generation import get_max_prompt_length + +import llama_cpp + +if torch.cuda.is_available() and not torch.version.hip: + try: + import llama_cpp_cuda + except: + llama_cpp_cuda = None +else: + llama_cpp_cuda = None + + +def llama_cpp_lib(): + if shared.args.cpu or llama_cpp_cuda is None: + return llama_cpp + else: + return llama_cpp_cuda def ban_eos_logits_processor(eos_token, input_ids, logits): @@ -21,6 +32,13 @@ def ban_eos_logits_processor(eos_token, input_ids, logits): return logits +def custom_token_ban_logits_processor(token_ids, input_ids, logits): + for token_id in token_ids: + logits[token_id] = -float('inf') + + return logits + + class LlamaCppModel: def __init__(self): self.initialized = False @@ -30,6 +48,10 @@ class LlamaCppModel: @classmethod def from_pretrained(self, path): + + Llama = llama_cpp_lib().Llama + LlamaCache = llama_cpp_lib().LlamaCache + result = self() cache_capacity = 0 if shared.args.cache_capacity is not None: @@ -41,6 +63,12 @@ class LlamaCppModel: cache_capacity = int(shared.args.cache_capacity) logger.info("Cache capacity is " + str(cache_capacity) + " bytes") + + if shared.args.tensor_split is None or shared.args.tensor_split.strip() == '': + tensor_split_list = None + else: + tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")] + params = { 'model_path': str(path), 'n_ctx': shared.args.n_ctx, @@ -49,7 +77,12 @@ class LlamaCppModel: 'n_batch': shared.args.n_batch, 'use_mmap': not shared.args.no_mmap, 'use_mlock': shared.args.mlock, - 'n_gpu_layers': shared.args.n_gpu_layers + 'mul_mat_q': shared.args.mul_mat_q, + 'low_vram': shared.args.low_vram, + 'n_gpu_layers': shared.args.n_gpu_layers, + 'rope_freq_base': RoPE.get_rope_freq_base(shared.args.alpha_value, shared.args.rope_freq_base), + 'tensor_split': tensor_split_list, + 'rope_freq_scale': 1.0 / shared.args.compress_pos_emb, } result.model = Llama(**params) @@ -65,11 +98,35 @@ class LlamaCppModel: return self.model.tokenize(string) - def decode(self, tokens): - return self.model.detokenize(tokens) + def decode(self, ids): + return self.model.detokenize(ids).decode('utf-8') + + def get_logits(self, tokens): + self.model.eval(tokens) + logits = self.model._scores + logits = np.expand_dims(logits, 0) # batch dim is expected + return torch.tensor(logits, dtype=torch.float32) def generate(self, prompt, state, callback=None): + + LogitsProcessorList = llama_cpp_lib().LogitsProcessorList + prompt = prompt if type(prompt) is str else prompt.decode() + + # Handle truncation + prompt = self.encode(prompt) + prompt = prompt[-get_max_prompt_length(state):] + prompt = self.decode(prompt) + + logit_processors = LogitsProcessorList() + if state['ban_eos_token']: + logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos())) + + if state['custom_token_bans']: + to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + if len(to_ban) > 0: + logit_processors.append(partial(custom_token_ban_logits_processor, to_ban)) + completion_chunks = self.model.create_completion( prompt=prompt, max_tokens=state['max_new_tokens'], @@ -82,13 +139,13 @@ class LlamaCppModel: mirostat_tau=state['mirostat_tau'], mirostat_eta=state['mirostat_eta'], stream=True, - logits_processor=LogitsProcessorList([ - partial(ban_eos_logits_processor, self.model.token_eos()), - ]) if state['ban_eos_token'] else None, + logits_processor=logit_processors, ) output = "" for completion_chunk in completion_chunks: + if shared.stop_everything: + break text = completion_chunk['choices'][0]['text'] output += text if callback: diff --git a/modules/loaders.py b/modules/loaders.py index 8ec575a..b7187e5 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -1,10 +1,60 @@ import functools +from collections import OrderedDict import gradio as gr from modules import shared -loaders_and_params = { +loaders_and_params = OrderedDict({ + 'Transformers': [ + 'cpu_memory', + 'gpu_memory', + 'trust_remote_code', + 'load_in_8bit', + 'bf16', + 'cpu', + 'disk', + 'auto_devices', + 'load_in_4bit', + 'use_double_quant', + 'quant_type', + 'compute_dtype', + 'trust_remote_code', + 'alpha_value', + 'rope_freq_base', + 'compress_pos_emb', + 'transformers_info' + ], + 'ExLlama_HF': [ + 'gpu_split', + 'max_seq_len', + 'alpha_value', + 'rope_freq_base', + 'compress_pos_emb', + 'cfg_cache', + 'exllama_HF_info', + ], + 'ExLlamav2_HF': [ + 'gpu_split', + 'max_seq_len', + 'cfg_cache', + 'alpha_value', + 'compress_pos_emb', + ], + 'ExLlama': [ + 'gpu_split', + 'max_seq_len', + 'alpha_value', + 'rope_freq_base', + 'compress_pos_emb', + 'exllama_info', + ], + 'ExLlamav2': [ + 'gpu_split', + 'max_seq_len', + 'alpha_value', + 'compress_pos_emb', + ], 'AutoGPTQ': [ 'triton', 'no_inject_fused_attention', @@ -13,6 +63,7 @@ loaders_and_params = { 'wbits', 'groupsize', 'desc_act', + 'disable_exllama', 'gpu_memory', 'cpu_memory', 'cpu', @@ -31,44 +82,315 @@ loaders_and_params = { 'llama.cpp': [ 'n_ctx', 'n_gpu_layers', + 'tensor_split', 'n_batch', 'threads', 'no_mmap', + 'low_vram', 'mlock', + 'mul_mat_q', 'llama_cpp_seed', - ], - 'Transformers': [ - 'cpu_memory', - 'gpu_memory', - 'trust_remote_code', - 'load_in_8bit', - 'bf16', + 'alpha_value', + 'rope_freq_base', + 'compress_pos_emb', 'cpu', - 'disk', - 'auto_devices', - 'load_in_4bit', - 'use_double_quant', - 'quant_type', - 'compute_dtype', - 'trust_remote_code', - 'transformers_info' ], - 'ExLlama' : [ - 'gpu_split', - 'max_seq_len', - 'compress_pos_emb', + 'llamacpp_HF': [ + 'n_ctx', + 'n_gpu_layers', + 'tensor_split', + 'n_batch', + 'threads', + 'no_mmap', + 'low_vram', + 'mlock', + 'mul_mat_q', 'alpha_value', - 'exllama_info', + 'rope_freq_base', + 'compress_pos_emb', + 'cpu', + 'cfg_cache', + 'llamacpp_HF_info', ], - 'ExLlama_HF' : [ - 'gpu_split', - 'max_seq_len', - 'compress_pos_emb', - 'alpha_value', - 'exllama_HF_info', + 'ctransformers': [ + 'n_ctx', + 'n_gpu_layers', + 'n_batch', + 'threads', + 'model_type', + 'no_mmap', + 'mlock' ] +}) + +loaders_samplers = { + 'Transformers': { + 'temperature', + 'top_p', + 'top_k', + 'typical_p', + 'epsilon_cutoff', + 'eta_cutoff', + 'tfs', + 'top_a', + 'repetition_penalty', + 'repetition_penalty_range', + 'encoder_repetition_penalty', + 'no_repeat_ngram_size', + 'min_length', + 'seed', + 'do_sample', + 'penalty_alpha', + 'num_beams', + 'length_penalty', + 'early_stopping', + 'mirostat_mode', + 'mirostat_tau', + 'mirostat_eta', + 'guidance_scale', + 'negative_prompt', + 'ban_eos_token', + 'custom_token_bans', + 'add_bos_token', + 'skip_special_tokens', + 'auto_max_new_tokens', + }, + 'ExLlama_HF': { + 'temperature', + 'top_p', + 'top_k', + 'typical_p', + 'epsilon_cutoff', + 'eta_cutoff', + 'tfs', + 'top_a', + 'repetition_penalty', + 'repetition_penalty_range', + 'encoder_repetition_penalty', + 'no_repeat_ngram_size', + 'min_length', + 'seed', + 'do_sample', + 'mirostat_mode', + 'mirostat_tau', + 'mirostat_eta', + 'guidance_scale', + 'negative_prompt', + 'ban_eos_token', + 'custom_token_bans', + 'add_bos_token', + 'skip_special_tokens', + 'auto_max_new_tokens', + }, + 'ExLlama': { + 'temperature', + 'top_p', + 'top_k', + 'typical_p', + 'repetition_penalty', + 'repetition_penalty_range', + 'seed', + 'guidance_scale', + 'negative_prompt', + 'ban_eos_token', + 'custom_token_bans', + 'auto_max_new_tokens', + }, + 'ExLlamav2': { + 'temperature', + 'top_p', + 'top_k', + 'repetition_penalty', + 'repetition_penalty_range', + 'seed', + 'ban_eos_token', + 'custom_token_bans', + 'auto_max_new_tokens', + }, + 'ExLlamav2_HF': { + 'temperature', + 'top_p', + 'top_k', + 'typical_p', + 'epsilon_cutoff', + 'eta_cutoff', + 'tfs', + 'top_a', + 'repetition_penalty', + 'repetition_penalty_range', + 'encoder_repetition_penalty', + 'no_repeat_ngram_size', + 'min_length', + 'seed', + 'do_sample', + 'mirostat_mode', + 'mirostat_tau', + 'mirostat_eta', + 'guidance_scale', + 'negative_prompt', + 'ban_eos_token', + 'custom_token_bans', + 'add_bos_token', + 'skip_special_tokens', + 'auto_max_new_tokens', + }, + 'AutoGPTQ': { + 'temperature', + 'top_p', + 'top_k', + 'typical_p', + 'epsilon_cutoff', + 'eta_cutoff', + 'tfs', + 'top_a', + 'repetition_penalty', + 'repetition_penalty_range', + 'encoder_repetition_penalty', + 'no_repeat_ngram_size', + 'min_length', + 'seed', + 'do_sample', + 'penalty_alpha', + 'num_beams', + 'length_penalty', + 'early_stopping', + 'mirostat_mode', + 'mirostat_tau', + 'mirostat_eta', + 'guidance_scale', + 'negative_prompt', + 'ban_eos_token', + 'custom_token_bans', + 'add_bos_token', + 'skip_special_tokens', + 'auto_max_new_tokens', + }, + 'GPTQ-for-LLaMa': { + 'temperature', + 'top_p', + 'top_k', + 'typical_p', + 'epsilon_cutoff', + 'eta_cutoff', + 'tfs', + 'top_a', + 'repetition_penalty', + 'repetition_penalty_range', + 'encoder_repetition_penalty', + 'no_repeat_ngram_size', + 'min_length', + 'seed', + 'do_sample', + 'penalty_alpha', + 'num_beams', + 'length_penalty', + 'early_stopping', + 'mirostat_mode', + 'mirostat_tau', + 'mirostat_eta', + 'guidance_scale', + 'negative_prompt', + 'ban_eos_token', + 'custom_token_bans', + 'add_bos_token', + 'skip_special_tokens', + 'auto_max_new_tokens', + }, + 'llama.cpp': { + 'temperature', + 'top_p', + 'top_k', + 'tfs', + 'repetition_penalty', + 'mirostat_mode', + 'mirostat_tau', + 'mirostat_eta', + 'ban_eos_token', + 'custom_token_bans', + }, + 'llamacpp_HF': { + 'temperature', + 'top_p', + 'top_k', + 'typical_p', + 'epsilon_cutoff', + 'eta_cutoff', + 'tfs', + 'top_a', + 'repetition_penalty', + 'repetition_penalty_range', + 'encoder_repetition_penalty', + 'no_repeat_ngram_size', + 'min_length', + 'seed', + 'do_sample', + 'mirostat_mode', + 'mirostat_tau', + 'mirostat_eta', + 'guidance_scale', + 'negative_prompt', + 'ban_eos_token', + 'custom_token_bans', + 'add_bos_token', + 'skip_special_tokens', + 'auto_max_new_tokens', + }, + 'ctransformers': { + 'temperature', + 'top_p', + 'top_k', + 'repetition_penalty', + 'repetition_penalty_range', + } } +loaders_model_types = { + 'GPTQ-for-LLaMa': [ + "None", + "llama", + "opt", + "gptj" + ], + 'ctransformers': [ + "None", + "gpt2", + "gptj", + "gptneox", + "llama", + "mpt", + "dollyv2", + "replit", + "starcoder", + "gptbigcode", + "falcon" + ], +} + + +@functools.cache +def list_all_samplers(): + all_samplers = set() + for k in loaders_samplers: + for sampler in loaders_samplers[k]: + all_samplers.add(sampler) + + return sorted(all_samplers) + + +def blacklist_samplers(loader): + all_samplers = list_all_samplers() + if loader == 'All': + return [gr.update(visible=True) for sampler in all_samplers] + else: + return [gr.update(visible=True) if sampler in loaders_samplers[loader] else gr.update(visible=False) for sampler in all_samplers] + + +def get_model_types(loader): + if loader in loaders_model_types: + return loaders_model_types[loader] + + return ["None"] + def get_gpu_memory_keys(): return [k for k in shared.gradio if k.startswith('gpu_memory')] diff --git a/modules/logits.py b/modules/logits.py new file mode 100644 index 0000000..6fc5bf6 --- /dev/null +++ b/modules/logits.py @@ -0,0 +1,56 @@ +import torch + +from modules import sampler_hijack, shared +from modules.logging_colors import logger +from modules.text_generation import generate_reply + +global_scores = None + + +def get_next_logits(prompt, state, use_samplers, previous): + if shared.model is None: + logger.error("No model is loaded! Select one in the Model tab.") + return 'Error: No model is loaded1 Select one in the Model tab.', previous + + is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model' + is_non_hf_exllamav1 = shared.model.__class__.__name__ == 'ExllamaModel' + is_non_hf_llamacpp = shared.model.__class__.__name__ == 'LlamaCppModel' + + if use_samplers: + if any([is_non_hf_exllamav2, is_non_hf_exllamav1, is_non_hf_llamacpp]): + logger.error("Sampler hijacking is not supported non-Huggingface loaders.") + # sampling is all done in c for exllama, so it is really hard to hijack + # it should be possible to hijack llamacpp sampler by hijacking all their sampling methods, + # but it is not implemented yet + return 'Error: Sampler hijacking is not supported non-Huggingface loaders. Please disable the "Use samplers" option.', previous + + state['max_new_tokens'] = 1 + state['auto_max_new_tokens'] = False + for _ in generate_reply(prompt, state): + pass + + scores = sampler_hijack.global_scores[-1] + else: + if is_non_hf_exllamav2 or is_non_hf_exllamav1: + tokens = shared.tokenizer.encode(prompt).cuda() + scores = shared.model.get_logits(tokens)[-1][-1] + elif is_non_hf_llamacpp: + tokens = shared.tokenizer.encode(prompt) + scores = shared.model.get_logits(tokens)[-1][-1] + else: + tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda() + output = shared.model(input_ids=tokens) + scores = output['logits'][-1][-1] + + probs = torch.softmax(scores, dim=-1, dtype=torch.float) + topk_values, topk_indices = torch.topk(probs, k=50, largest=True, sorted=True) + topk_values = [f"{float(i):.5f}" for i in topk_values] + if is_non_hf_exllamav1 or is_non_hf_llamacpp: + topk_indices = [i.expand((1, 1)) for i in topk_indices] + + tokens = [shared.tokenizer.decode(i) for i in topk_indices] + output = '' + for row in list(zip(topk_values, tokens)): + output += f"{row[0]} - {repr(row[1])}\n" + + return output, previous diff --git a/modules/metadata_gguf.py b/modules/metadata_gguf.py new file mode 100644 index 0000000..0ea41a2 --- /dev/null +++ b/modules/metadata_gguf.py @@ -0,0 +1,91 @@ +import struct +from enum import IntEnum + + +class GGUFValueType(IntEnum): + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 + FLOAT32 = 6 + BOOL = 7 + STRING = 8 + ARRAY = 9 + UINT64 = 10 + INT64 = 11 + FLOAT64 = 12 + + +_simple_value_packing = { + GGUFValueType.UINT8: " 1, shared.args.alpha_value > 1]): model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code) - if torch.has_mps: + if torch.backends.mps.is_available(): device = torch.device('mps') model = model.to(device) else: @@ -166,7 +148,7 @@ def huggingface_loader(model_name): "trust_remote_code": shared.args.trust_remote_code } - if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)): + if not any((shared.args.cpu, torch.cuda.is_available(), torch.backends.mps.is_available())): logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.") shared.args.cpu = True @@ -215,37 +197,16 @@ def huggingface_loader(model_name): no_split_module_classes=model._no_split_modules ) + if shared.args.compress_pos_emb > 1: + params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb} + elif shared.args.alpha_value > 1: + params['rope_scaling'] = {'type': 'dynamic', 'factor': RoPE.get_alpha_value(shared.args.alpha_value, shared.args.rope_freq_base)} + model = LoaderClass.from_pretrained(checkpoint, **params) return model -def flexgen_loader(model_name): - from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy - - # Initialize environment - env = ExecutionEnv.create(shared.args.disk_cache_dir) - - # Offloading policy - policy = Policy(1, 1, - shared.args.percent[0], shared.args.percent[1], - shared.args.percent[2], shared.args.percent[3], - shared.args.percent[4], shared.args.percent[5], - overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight, - cpu_cache_compute=False, attn_sparsity=1.0, - compress_weight=shared.args.compress_weight, - comp_weight_config=CompressionConfig( - num_bits=4, group_size=64, - group_dim=0, symmetric=False), - compress_cache=False, - comp_cache_config=CompressionConfig( - num_bits=4, group_size=64, - group_dim=2, symmetric=False)) - - model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy) - return model - - def RWKV_loader(model_name): from modules.RWKV import RWKVModel, RWKVTokenizer @@ -261,13 +222,62 @@ def llamacpp_loader(model_name): if path.is_file(): model_file = path else: - model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0] + model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0] - logger.info(f"llama.cpp weights detected: {model_file}\n") + logger.info(f"llama.cpp weights detected: {model_file}") model, tokenizer = LlamaCppModel.from_pretrained(model_file) return model, tokenizer +def llamacpp_HF_loader(model_name): + from modules.llamacpp_hf import LlamacppHF + + for fname in [model_name, "oobabooga_llama-tokenizer", "llama-tokenizer"]: + path = Path(f'{shared.args.model_dir}/{fname}') + if all((path / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']): + logger.info(f'Using tokenizer from: {path}') + break + else: + logger.error("Could not load the model because a tokenizer in transformers format was not found. Please download oobabooga/llama-tokenizer.") + return None, None + + tokenizer = AutoTokenizer.from_pretrained( + path, + trust_remote_code=shared.args.trust_remote_code, + use_fast=False + ) + + model = LlamacppHF.from_pretrained(model_name) + return model, tokenizer + + +def ctransformers_loader(model_name): + from modules.ctransformers_model import CtransformersModel + + path = Path(f'{shared.args.model_dir}/{model_name}') + ctrans = CtransformersModel() + if ctrans.model_type_is_auto(): + model_file = path + else: + if path.is_file(): + model_file = path + else: + entries = Path(f'{shared.args.model_dir}/{model_name}') + gguf = list(entries.glob('*.gguf')) + bin = list(entries.glob('*.bin')) + if len(gguf) > 0: + model_file = gguf[0] + elif len(bin) > 0: + model_file = bin[0] + else: + logger.error("Could not find a model for ctransformers.") + return None, None + + logger.info(f'ctransformers weights detected: {model_file}') + model, tokenizer = ctrans.from_pretrained(model_file) + return model, tokenizer + + def GPTQ_loader(model_name): # Monkey patch @@ -305,6 +315,19 @@ def ExLlama_HF_loader(model_name): return ExllamaHF.from_pretrained(model_name) +def ExLlamav2_loader(model_name): + from modules.exllamav2 import Exllamav2Model + + model, tokenizer = Exllamav2Model.from_pretrained(model_name) + return model, tokenizer + + +def ExLlamav2_HF_loader(model_name): + from modules.exllamav2_hf import Exllamav2HF + + return Exllamav2HF.from_pretrained(model_name) + + def get_max_memory_dict(): max_memory = {} if shared.args.gpu_memory: @@ -339,6 +362,7 @@ def clear_torch_cache(): def unload_model(): shared.model = shared.tokenizer = None shared.lora_names = [] + shared.model_dirty_from_training = False clear_torch_cache() diff --git a/modules/models_settings.py b/modules/models_settings.py index 0207e7d..bc3ace6 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -3,12 +3,59 @@ from pathlib import Path import yaml -from modules import shared, ui +from modules import loaders, metadata_gguf, shared, ui -def get_model_settings_from_yamls(model): - settings = shared.model_config +def get_fallback_settings(): + return { + 'wbits': 'None', + 'model_type': 'None', + 'groupsize': 'None', + 'pre_layer': 0, + 'skip_special_tokens': shared.settings['skip_special_tokens'], + 'custom_stopping_strings': shared.settings['custom_stopping_strings'], + 'truncation_length': shared.settings['truncation_length'], + 'n_ctx': 2048, + 'rope_freq_base': 0, + 'compress_pos_emb': 1, + } + + +def get_model_metadata(model): model_settings = {} + + # Get settings from models/config.yaml and models/config-user.yaml + settings = shared.model_config + for pat in settings: + if re.match(pat.lower(), model.lower()): + for k in settings[pat]: + model_settings[k] = settings[pat][k] + + if 'loader' not in model_settings: + loader = infer_loader(model, model_settings) + if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0: + loader = 'AutoGPTQ' + + model_settings['loader'] = loader + + # Read GGUF metadata + if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']: + path = Path(f'{shared.args.model_dir}/{model}') + if path.is_file(): + model_file = path + else: + model_file = list(path.glob('*.gguf'))[0] + + metadata = metadata_gguf.load_metadata(model_file) + if 'llama.context_length' in metadata: + model_settings['n_ctx'] = metadata['llama.context_length'] + if 'llama.rope.scale_linear' in metadata: + model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear'] + if 'llama.rope.freq_base' in metadata: + model_settings['rope_freq_base'] = metadata['llama.rope.freq_base'] + + # Apply user settings from models/config-user.yaml + settings = shared.user_config for pat in settings: if re.match(pat.lower(), model.lower()): for k in settings[pat]: @@ -17,21 +64,18 @@ def get_model_settings_from_yamls(model): return model_settings -def infer_loader(model_name): +def infer_loader(model_name, model_settings): path_to_model = Path(f'{shared.args.model_dir}/{model_name}') - model_settings = get_model_settings_from_yamls(model_name) if not path_to_model.exists(): loader = None elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0): loader = 'AutoGPTQ' - elif len(list(path_to_model.glob('*ggml*.bin'))) > 0: + elif len(list(path_to_model.glob('*.gguf'))) > 0: loader = 'llama.cpp' - elif re.match('.*ggml.*\.bin', model_name.lower()): + elif re.match(r'.*\.gguf', model_name.lower()): loader = 'llama.cpp' - elif re.match('.*rwkv.*\.pth', model_name.lower()): + elif re.match(r'.*rwkv.*\.pth', model_name.lower()): loader = 'RWKV' - elif shared.args.flexgen: - loader = 'FlexGen' else: loader = 'Transformers' @@ -52,7 +96,7 @@ def update_model_parameters(state, initial=False): gpu_memories.append(value) continue - if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]: + if initial and element in shared.provided_arguments: continue # Setting null defaults @@ -87,19 +131,20 @@ def update_model_parameters(state, initial=False): # UI: update the state variable with the model settings def apply_model_settings_to_state(model, state): - model_settings = get_model_settings_from_yamls(model) - if 'loader' not in model_settings: - loader = infer_loader(model) - if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0: - loader = 'AutoGPTQ' + model_settings = get_model_metadata(model) + if 'loader' in model_settings: + loader = model_settings.pop('loader') - # If the user is using an alternative GPTQ loader, let them keep using it - if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']): + # If the user is using an alternative loader for the same model type, let them keep using it + if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF', 'ExLlamav2', 'ExLlamav2_HF']) and not (loader == 'llama.cpp' and state['loader'] in ['llamacpp_HF', 'ctransformers']): state['loader'] = loader for k in model_settings: if k in state: - state[k] = model_settings[k] + if k in ['wbits', 'groupsize']: + state[k] = str(model_settings[k]) + else: + state[k] = model_settings[k] return state @@ -117,18 +162,17 @@ def save_model_settings(model, state): user_config = {} model_regex = model + '$' # For exact matches - for _dict in [user_config, shared.model_config]: - if model_regex not in _dict: - _dict[model_regex] = {} - if model_regex not in user_config: user_config[model_regex] = {} for k in ui.list_model_elements(): - user_config[model_regex][k] = state[k] - shared.model_config[model_regex][k] = state[k] + if k == 'loader' or k in loaders.loaders_and_params[state['loader']]: + user_config[model_regex][k] = state[k] + shared.user_config = user_config + + output = yaml.dump(user_config, sort_keys=False) with open(p, 'w') as f: - f.write(yaml.dump(user_config, sort_keys=False)) + f.write(output) yield (f"Settings for {model} saved to {p}") diff --git a/modules/monkey_patch_gptq_lora.py b/modules/monkey_patch_gptq_lora.py index bf8d478..3166bd3 100644 --- a/modules/monkey_patch_gptq_lora.py +++ b/modules/monkey_patch_gptq_lora.py @@ -1,39 +1,35 @@ # Copied from https://github.com/johnsmith0031/alpaca_lora_4bit -import sys from pathlib import Path -sys.path.insert(0, str(Path("repositories/alpaca_lora_4bit"))) - -import autograd_4bit -from amp_wrapper import AMPWrapper -from autograd_4bit import ( +import alpaca_lora_4bit.autograd_4bit as autograd_4bit +from alpaca_lora_4bit.amp_wrapper import AMPWrapper +from alpaca_lora_4bit.autograd_4bit import ( Autograd4bitQuantLinear, load_llama_model_4bit_low_ram ) -from monkeypatch.peft_tuners_lora_monkey_patch import ( - Linear4bitLt, - replace_peft_model_with_gptq_lora_model +from alpaca_lora_4bit.models import Linear4bitLt +from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( + replace_peft_model_with_int4_lora_model ) from modules import shared from modules.GPTQ_loader import find_quantized_model_file -replace_peft_model_with_gptq_lora_model() +replace_peft_model_with_int4_lora_model() def load_model_llama(model_name): config_path = str(Path(f'{shared.args.model_dir}/{model_name}')) model_path = str(find_quantized_model_file(model_name)) model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=shared.args.groupsize, is_v1_model=False) - for n, m in model.named_modules(): + for _, m in model.named_modules(): if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt): if m.is_v1_model: m.zeros = m.zeros.half() m.scales = m.scales.half() m.bias = m.bias.half() - autograd_4bit.use_new = True autograd_4bit.auto_switch = True model.half() diff --git a/modules/presets.py b/modules/presets.py index 0af2928..96d6e99 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -4,11 +4,12 @@ from pathlib import Path import yaml -def load_preset(name): - generate_params = { +def default_preset(): + return { 'do_sample': True, 'temperature': 1, 'top_p': 1, + 'top_k': 0, 'typical_p': 1, 'epsilon_cutoff': 0, 'eta_cutoff': 0, @@ -17,18 +18,26 @@ def load_preset(name): 'repetition_penalty': 1, 'repetition_penalty_range': 0, 'encoder_repetition_penalty': 1, - 'top_k': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'min_length': 0, - 'length_penalty': 1, 'no_repeat_ngram_size': 0, - 'early_stopping': False, + 'min_length': 0, + 'guidance_scale': 1, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1, + 'penalty_alpha': 0, + 'num_beams': 1, + 'length_penalty': 1, + 'early_stopping': False, + 'custom_token_bans': '', } + +def presets_params(): + return [k for k in default_preset()] + + +def load_preset(name): + generate_params = default_preset() if name not in ['None', None, '']: with open(Path(f'presets/{name}.yaml'), 'r') as infile: preset = yaml.safe_load(infile) @@ -48,9 +57,16 @@ def load_preset_memoized(name): def load_preset_for_ui(name, state): generate_params = load_preset(name) state.update(generate_params) - return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']] + return state, *[generate_params[k] for k in presets_params()] def generate_preset_yaml(state): - data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']} + defaults = default_preset() + data = {k: state[k] for k in presets_params()} + + # Remove entries that are identical to the defaults + for k in list(data.keys()): + if data[k] == defaults[k]: + del data[k] + return yaml.dump(data, sort_keys=False) diff --git a/modules/prompts.py b/modules/prompts.py new file mode 100644 index 0000000..ce652de --- /dev/null +++ b/modules/prompts.py @@ -0,0 +1,51 @@ +from pathlib import Path + +import yaml + +from modules import utils +from modules.text_generation import get_encoded_length + + +def load_prompt(fname): + if fname in ['None', '']: + return '' + else: + file_path = Path(f'prompts/{fname}.txt') + if not file_path.exists(): + return '' + + with open(file_path, 'r', encoding='utf-8') as f: + text = f.read() + if text[-1] == '\n': + text = text[:-1] + + return text + + +def load_instruction_prompt_simple(fname): + file_path = Path(f'instruction-templates/{fname}.yaml') + if not file_path.exists(): + return '' + + with open(file_path, 'r', encoding='utf-8') as f: + data = yaml.safe_load(f) + output = '' + if 'context' in data: + output += data['context'] + + replacements = { + '<|user|>': data['user'], + '<|bot|>': data['bot'], + '<|user-message|>': 'Input', + } + + output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements) + return output.rstrip(' ') + + +def count_tokens(text): + try: + tokens = get_encoded_length(text) + return str(tokens) + except: + return '0' diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 391ece9..0a724f4 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -10,6 +10,8 @@ from transformers.generation.logits_process import ( TemperatureLogitsWarper ) +global_scores = None + class TailFreeLogitsWarper(LogitsWarper): def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -104,7 +106,7 @@ class MirostatLogitsWarper(LogitsWarper): break # Normalize the probabilities of the remaining words - prob_topk = torch.softmax(sorted_logits, dim=0) + prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda') prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda') @@ -122,10 +124,21 @@ class MirostatLogitsWarper(LogitsWarper): return scores +class SpyLogitsWarper(LogitsWarper): + def __init__(self): + pass + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + global global_scores + global_scores = scores + return scores + + class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor): ''' Copied from the transformers library ''' + def __init__(self, penalty: float, _range: int): if not isinstance(penalty, float) or not (penalty > 0): raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") @@ -167,6 +180,7 @@ def get_logits_warper_patch(self, generation_config): else: warpers += warpers_to_add + warpers.append(SpyLogitsWarper()) return warpers diff --git a/modules/shared.py b/modules/shared.py index 2b2fa06..e534af2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,4 +1,5 @@ import argparse +import sys from collections import OrderedDict from pathlib import Path @@ -6,62 +7,56 @@ import yaml from modules.logging_colors import logger -generation_lock = None +# Model variables model = None tokenizer = None -is_seq2seq = False model_name = "None" +is_seq2seq = False +model_dirty_from_training = False lora_names = [] -# Chat variables +# Generation variables stop_everything = False +generation_lock = None processing_message = '*Is typing...*' -# UI elements (buttons, sliders, HTML, etc) +# UI variables gradio = {} - -# For keeping the values of UI elements on page reload persistent_interface_state = {} - -input_params = [] # Generation input parameters -reload_inputs = [] # Parameters for reloading the chat interface - -# For restarting the interface need_restart = False +# UI defaults settings = { - 'dark_theme': False, - 'autoload_model': True, + 'dark_theme': True, + 'show_controls': True, + 'start_with': '', + 'mode': 'chat', + 'chat_style': 'cai-chat', + 'prompt-default': 'QA', + 'prompt-notebook': 'QA', + 'preset': 'simple-1', 'max_new_tokens': 200, 'max_new_tokens_min': 1, - 'max_new_tokens_max': 2000, + 'max_new_tokens_max': 4096, 'seed': -1, - 'character': 'None', - 'name1': 'You', - 'name2': 'Assistant', - 'context': 'This is a conversation with your Assistant. It is a computer program designed to help you with various tasks such as answering questions, providing recommendations, and helping with decision making. You can ask it anything you want and it will do its best to give you accurate and relevant information.', - 'greeting': '', - 'turn_template': '', - 'custom_stopping_strings': '', - 'stop_at_newline': False, - 'add_bos_token': True, - 'ban_eos_token': False, - 'skip_special_tokens': True, + 'negative_prompt': '', 'truncation_length': 2048, 'truncation_length_min': 0, 'truncation_length_max': 16384, - 'mode': 'chat', - 'start_with': '', - 'chat_style': 'cai-chat', - 'instruction_template': 'None', + 'custom_stopping_strings': '', + 'auto_max_new_tokens': False, + 'max_tokens_second': 0, + 'ban_eos_token': False, + 'custom_token_bans': '', + 'add_bos_token': True, + 'skip_special_tokens': True, + 'stream': True, + 'name1': 'You', + 'character': 'Assistant', + 'instruction_template': 'Alpaca', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', - 'chat_generation_attempts': 1, - 'chat_generation_attempts_min': 1, - 'chat_generation_attempts_max': 10, - 'default_extensions': [], - 'chat_default_extensions': ['gallery'], - 'preset': 'simple-1', - 'prompt': 'QA', + 'autoload_model': False, + 'default_extensions': ['gallery'], } @@ -79,8 +74,8 @@ def str2bool(v): parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) # Basic settings -parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') -parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') +parser.add_argument('--notebook', action='store_true', help='DEPRECATED') +parser.add_argument('--chat', action='store_true', help='DEPRECATED') parser.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. WARNING: this is highly experimental.') parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.') parser.add_argument('--model', type=str, help='Name of the model to load by default.') @@ -88,13 +83,14 @@ parser.add_argument('--lora', type=str, nargs="+", help='The list of LoRAs to lo parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras") parser.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.') -parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.') +parser.add_argument('--no-stream', action='store_true', help='DEPRECATED') parser.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See settings-template.yaml for an example. If you create a file called settings.yaml, this file will be loaded by default without the need to use the --settings flag.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') +parser.add_argument('--chat-buttons', action='store_true', help='Show buttons on chat tab instead of hover menu.') # Model loader -parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen') +parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv') # Accelerate/transformers parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.') @@ -120,9 +116,12 @@ parser.add_argument('--use_double_quant', action='store_true', help='use_double_ parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.') parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.') parser.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.') +parser.add_argument('--low-vram', action='store_true', help='Low VRAM Mode') parser.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.') +parser.add_argument('--mul_mat_q', action='store_true', help='Activate new mulmat kernels.') parser.add_argument('--cache-capacity', type=str, help='Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.') parser.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.') +parser.add_argument('--tensor_split', type=str, default=None, help="Split the model across multiple GPUs, comma-separated list of proportions, e.g. 18,17") parser.add_argument('--n_ctx', type=int, default=2048, help='Size of the prompt context.') parser.add_argument('--llama_cpp_seed', type=int, default=0, help='Seed for llama-cpp models. Default 0 (random)') @@ -133,30 +132,19 @@ parser.add_argument('--groupsize', type=int, default=-1, help='Group size.') parser.add_argument('--pre_layer', type=int, nargs="+", help='The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. For multi-gpu, write the numbers separated by spaces, eg --pre_layer 30 60.') parser.add_argument('--checkpoint', type=str, help='The path to the quantized checkpoint file. If not specified, it will be automatically detected.') parser.add_argument('--monkey-patch', action='store_true', help='Apply the monkey patch for using LoRAs with quantized models.') -parser.add_argument('--quant_attn', action='store_true', help='(triton) Enable quant attention.') -parser.add_argument('--warmup_autotune', action='store_true', help='(triton) Enable warmup autotune.') -parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fused mlp.') # AutoGPTQ -parser.add_argument('--gptq-for-llama', action='store_true', help='DEPRECATED') -parser.add_argument('--autogptq', action='store_true', help='DEPRECATED') parser.add_argument('--triton', action='store_true', help='Use triton.') parser.add_argument('--no_inject_fused_attention', action='store_true', help='Do not use fused attention (lowers VRAM requirements).') parser.add_argument('--no_inject_fused_mlp', action='store_true', help='Triton mode only: Do not use fused MLP (lowers VRAM requirements).') parser.add_argument('--no_use_cuda_fp16', action='store_true', help='This can make models faster on some systems.') parser.add_argument('--desc_act', action='store_true', help='For models that don\'t have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.') +parser.add_argument('--disable_exllama', action='store_true', help='Disable ExLlama kernel, which can improve inference speed on some systems.') # ExLlama parser.add_argument('--gpu-split', type=str, help="Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. 20,7,7") parser.add_argument('--max_seq_len', type=int, default=2048, help="Maximum sequence length.") -parser.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.") -parser.add_argument('--alpha_value', type=int, default=1, help="Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both.") - -# FlexGen -parser.add_argument('--flexgen', action='store_true', help='DEPRECATED') -parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).') -parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.") -parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).") +parser.add_argument('--cfg-cache', action='store_true', help="ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.") # DeepSpeed parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') @@ -167,6 +155,11 @@ parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Option parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".') parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.') +# RoPE +parser.add_argument('--alpha_value', type=float, default=1, help="Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.") +parser.add_argument('--rope_freq_base', type=int, default=0, help="If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).") +parser.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.") + # Gradio parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') parser.add_argument('--listen-host', type=str, help='The hostname that the server will use.') @@ -175,29 +168,31 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--ssl-keyfile", type=str, help='The path to the SSL certificate key file.', default=None) +parser.add_argument("--ssl-certfile", type=str, help='The path to the SSL certificate cert file.', default=None) # API parser.add_argument('--api', action='store_true', help='Enable the API extension.') parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.') -parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.') +parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.') parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.') +parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None) # Multimodal parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.') args = parser.parse_args() args_defaults = parser.parse_args([]) +provided_arguments = [] +for arg in sys.argv[1:]: + arg = arg.lstrip('-').replace('-', '_') + if hasattr(args, arg): + provided_arguments.append(arg) # Deprecation warnings -if args.autogptq: - logger.warning('--autogptq has been deprecated and will be removed soon. Use --loader autogptq instead.') - args.loader = 'autogptq' -if args.gptq_for_llama: - logger.warning('--gptq-for-llama has been deprecated and will be removed soon. Use --loader gptq-for-llama instead.') - args.loader = 'gptq-for-llama' -if args.flexgen: - logger.warning('--flexgen has been deprecated and will be removed soon. Use --loader flexgen instead.') - args.loader = 'FlexGen' +for k in ['chat', 'notebook', 'no_stream']: + if getattr(args, k): + logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.') # Security warnings if args.trust_remote_code: @@ -209,9 +204,14 @@ if args.multi_user: def fix_loader_name(name): + if not name: + return name + name = name.lower() if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']: return 'llama.cpp' + if name in ['llamacpp_hf', 'llama.cpp_hf', 'llama-cpp-hf', 'llamacpp-hf', 'llama.cpp-hf']: + return 'llamacpp_HF' elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']: return 'Transformers' elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']: @@ -222,10 +222,12 @@ def fix_loader_name(name): return 'ExLlama' elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']: return 'ExLlama_HF' - - -if args.loader is not None: - args.loader = fix_loader_name(args.loader) + elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']: + return 'ExLlamav2' + elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']: + return 'ExLlamav2_HF' + elif name in ['ctransformers', 'ctranforemrs', 'ctransformer']: + return 'ctransformers' def add_extension(name): @@ -235,43 +237,33 @@ def add_extension(name): args.extensions.append(name) -# Activating the API extension +def is_chat(): + return True + + +args.loader = fix_loader_name(args.loader) + +# Activate the API extension if args.api or args.public_api: add_extension('api') -# Activating the multimodal extension +# Activate the multimodal extension if args.multimodal_pipeline is not None: add_extension('multimodal') - -def is_chat(): - return args.chat - - -def get_mode(): - if args.chat: - return 'chat' - elif args.notebook: - return 'notebook' - else: - return 'default' - - -# Loading model-specific settings +# Load model-specific settings with Path(f'{args.model_dir}/config.yaml') as p: if p.exists(): model_config = yaml.safe_load(open(p, 'r').read()) else: model_config = {} -# Applying user-defined model settings +# Load custom model-specific settings with Path(f'{args.model_dir}/config-user.yaml') as p: if p.exists(): user_config = yaml.safe_load(open(p, 'r').read()) - for k in user_config: - if k in model_config: - model_config[k].update(user_config[k]) - else: - model_config[k] = user_config[k] + else: + user_config = {} model_config = OrderedDict(model_config) +user_config = OrderedDict(user_config) diff --git a/modules/text_generation.py b/modules/text_generation.py index b7f6edf..ab556a9 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -1,5 +1,6 @@ import ast import copy +import html import random import re import time @@ -8,6 +9,7 @@ import traceback import numpy as np import torch import transformers +from transformers import LogitsProcessorList import modules.shared as shared from modules.callbacks import ( @@ -30,15 +32,87 @@ def generate_reply(*args, **kwargs): shared.generation_lock.release() -def get_max_prompt_length(state): - return state['truncation_length'] - state['max_new_tokens'] +def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False): + + # Find the appropriate generation function + generate_func = apply_extensions('custom_generate_reply') + if generate_func is None: + if shared.model_name == 'None' or shared.model is None: + logger.error("No model is loaded! Select one in the Model tab.") + yield '' + return + + if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']: + generate_func = generate_reply_custom + else: + generate_func = generate_reply_HF + + # Prepare the input + original_question = question + if not is_chat: + state = apply_extensions('state', state) + question = apply_extensions('input', question, state) + + # Find the stopping strings + all_stop_strings = [] + for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")): + if type(st) is list and len(st) > 0: + all_stop_strings += st + + if shared.args.verbose: + print(f'\n\n{question}\n--------------------\n') + + shared.stop_everything = False + clear_torch_cache() + seed = set_manual_seed(state['seed']) + last_update = -1 + reply = '' + is_stream = state['stream'] + if len(all_stop_strings) > 0 and not state['stream']: + state = copy.deepcopy(state) + state['stream'] = True + + # Generate + for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat): + if escape_html: + reply = html.escape(reply) + + reply, stop_found = apply_stopping_strings(reply, all_stop_strings) + if is_stream: + cur_time = time.time() + + # Maximum number of tokens/second + if state['max_tokens_second'] > 0: + diff = 1 / state['max_tokens_second'] - (cur_time - last_update) + if diff > 0: + time.sleep(diff) + + last_update = time.time() + yield reply + + # Limit updates to 24 per second to not stress low latency networks + else: + if cur_time - last_update > 0.041666666666666664: + last_update = cur_time + yield reply + + if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything): + break + + if not is_chat: + reply = apply_extensions('output', reply, state) + + yield reply def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None): - if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']: + if shared.tokenizer is None: + raise ValueError('No tokenizer is loaded') + + if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model']: input_ids = shared.tokenizer.encode(str(prompt)) - input_ids = np.array(input_ids).reshape(1, len(input_ids)) - return input_ids + if shared.model.__class__.__name__ not in ['Exllamav2Model']: + input_ids = np.array(input_ids).reshape(1, len(input_ids)) else: input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens) @@ -50,19 +124,24 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt if truncation_length is not None: input_ids = input_ids[:, -truncation_length:] - if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu: + if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel'] or shared.args.cpu: return input_ids - elif shared.args.flexgen: - return input_ids.numpy() elif shared.args.deepspeed: return input_ids.to(device=local_rank) - elif torch.has_mps: + elif torch.backends.mps.is_available(): device = torch.device('mps') return input_ids.to(device) else: return input_ids.cuda() +def decode(output_ids, skip_special_tokens=True): + if shared.tokenizer is None: + raise ValueError('No tokenizer is loaded') + + return shared.tokenizer.decode(output_ids, skip_special_tokens) + + def get_encoded_length(prompt): length_after_extensions = apply_extensions('tokenized_length', prompt) if length_after_extensions is not None: @@ -71,12 +150,47 @@ def get_encoded_length(prompt): return len(encode(prompt)[0]) -def decode(output_ids, skip_special_tokens=True): - return shared.tokenizer.decode(output_ids, skip_special_tokens) +def get_token_ids(prompt): + tokens = encode(prompt)[0] + decoded_tokens = [shared.tokenizer.decode([i]) for i in tokens] + + output = '' + for row in list(zip(tokens, decoded_tokens)): + output += f"{str(int(row[0])).ljust(5)} - {repr(row[1])}\n" + + return output + + +def get_max_prompt_length(state): + return state['truncation_length'] - state['max_new_tokens'] + + +def generate_reply_wrapper(question, state, stopping_strings=None): + """ + Returns formatted outputs for the UI + """ + reply = question if not shared.is_seq2seq else '' + yield formatted_outputs(reply, shared.model_name) + + for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True): + if not shared.is_seq2seq: + reply = question + reply + + yield formatted_outputs(reply, shared.model_name) + + +def formatted_outputs(reply, model_name): + if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']): + reply = fix_gpt4chan(reply) + return html.unescape(reply), generate_4chan_html(reply) + else: + return html.unescape(reply), generate_basic_html(reply) -# Removes empty replies from gpt4chan outputs def fix_gpt4chan(s): + """ + Removes empty replies from gpt4chan outputs + """ for i in range(10): s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) s = re.sub("--- [0-9]*\n *\n---", "---", s) @@ -85,8 +199,10 @@ def fix_gpt4chan(s): return s -# Fix the LaTeX equations in galactica def fix_galactica(s): + """ + Fix the LaTeX equations in GALACTICA + """ s = s.replace(r'\[', r'$') s = s.replace(r'\]', r'$') s = s.replace(r'\(', r'$') @@ -111,14 +227,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i return reply -def formatted_outputs(reply, model_name): - if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']): - reply = fix_gpt4chan(reply) - return reply, generate_4chan_html(reply) - else: - return reply, generate_basic_html(reply) - - def set_manual_seed(seed): seed = int(seed) if seed == -1: @@ -135,17 +243,6 @@ def stop_everything_event(): shared.stop_everything = True -def generate_reply_wrapper(question, state, stopping_strings=None): - reply = question if not shared.is_seq2seq else '' - yield formatted_outputs(reply, shared.model_name) - - for reply in generate_reply(question, state, stopping_strings, is_chat=False): - if not shared.is_seq2seq: - reply = question + reply - - yield formatted_outputs(reply, shared.model_name) - - def apply_stopping_strings(reply, all_stop_strings): stop_found = False for string in all_stop_strings: @@ -171,68 +268,14 @@ def apply_stopping_strings(reply, all_stop_strings): return reply, stop_found -def _generate_reply(question, state, stopping_strings=None, is_chat=False): - generate_func = apply_extensions('custom_generate_reply') - if generate_func is None: - if shared.model_name == 'None' or shared.model is None: - logger.error("No model is loaded! Select one in the Model tab.") - yield '' - return - - if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']: - generate_func = generate_reply_custom - elif shared.args.flexgen: - generate_func = generate_reply_flexgen - else: - generate_func = generate_reply_HF - - # Preparing the input - original_question = question - if not is_chat: - state = apply_extensions('state', state) - question = apply_extensions('input', question, state) - - # Finding the stopping strings - all_stop_strings = [] - for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")): - if type(st) is list and len(st) > 0: - all_stop_strings += st - - if shared.args.verbose: - print(f'\n\n{question}\n--------------------\n') - - shared.stop_everything = False - clear_torch_cache() - seed = set_manual_seed(state['seed']) - last_update = -1 - reply = '' - is_stream = state['stream'] - if len(all_stop_strings) > 0 and not state['stream']: - state = copy.deepcopy(state) - state['stream'] = True - - for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat): - reply, stop_found = apply_stopping_strings(reply, all_stop_strings) - if is_stream: - cur_time = time.time() - if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps - last_update = cur_time - yield reply - - if stop_found: - break - - if not is_chat: - reply = apply_extensions('output', reply, state) - - yield reply - - def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): generate_params = {} - for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']: + for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']: generate_params[k] = state[k] + if state['negative_prompt'] != '': + generate_params['negative_prompt_ids'] = encode(state['negative_prompt']) + for k in ['epsilon_cutoff', 'eta_cutoff']: if state[k] > 0: generate_params[k] = state[k] * 1e-4 @@ -240,9 +283,15 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings if state['ban_eos_token']: generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id] - if shared.args.no_cache: - generate_params.update({'use_cache': False}) + if state['custom_token_bans']: + to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + if len(to_ban) > 0: + if generate_params.get('suppress_tokens', None): + generate_params['suppress_tokens'] += to_ban + else: + generate_params['suppress_tokens'] = to_ban + generate_params.update({'use_cache': not shared.args.no_cache}) if shared.args.deepspeed: generate_params.update({'synced_gpus': True}) @@ -250,6 +299,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) output = input_ids[0] cuda = not any((shared.args.cpu, shared.args.deepspeed)) + if state['auto_max_new_tokens']: + generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1] # Add the encoded tokens to generate_params question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None) @@ -264,6 +315,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria()) + processor = state.get('logits_processor', LogitsProcessorList([])) + # In case a processor is passed by itself. + if not isinstance(processor, LogitsProcessorList): + processor = LogitsProcessorList([processor]) + apply_extensions('logits_processor', processor, input_ids) + generate_params['logits_processor'] = processor + t0 = time.time() try: if not is_chat and not shared.is_seq2seq: @@ -308,6 +366,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False): + """ + For models that do not use the transformers library for sampling + """ seed = set_manual_seed(state['seed']) t0 = time.time() @@ -331,66 +392,3 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str new_tokens = len(encode(original_question + reply)[0]) - original_tokens print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') return - - -def generate_reply_flexgen(question, original_question, seed, state, stopping_strings=None, is_chat=False): - generate_params = {} - for k in ['max_new_tokens', 'do_sample', 'temperature']: - generate_params[k] = state[k] - - if state['stream']: - generate_params['max_new_tokens'] = 8 - - # Encode the input - input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) - output = input_ids[0] - - # Find the eos tokens - eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] - if not state['ban_eos_token']: - generate_params['stop'] = eos_token_ids[-1] - - # Add the encoded tokens to generate_params - question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None) - original_input_ids = input_ids - generate_params.update({'inputs': input_ids}) - if inputs_embeds is not None: - generate_params.update({'inputs_embeds': inputs_embeds}) - - t0 = time.time() - try: - if not is_chat: - yield '' - - # Generate the entire reply at once. - if not state['stream']: - with torch.no_grad(): - output = shared.model.generate(**generate_params)[0] - - yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat) - - # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' - else: - for i in range(state['max_new_tokens'] // 8 + 1): - if shared.stop_everything: - break - - clear_torch_cache() - with torch.no_grad(): - output = shared.model.generate(**generate_params)[0] - - if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): - break - - yield get_reply_from_output_ids(output, original_input_ids, original_question, state) - input_ids = np.reshape(output, (1, output.shape[0])) - generate_params.update({'inputs': input_ids}) - - except Exception: - traceback.print_exc() - finally: - t1 = time.time() - original_tokens = len(original_input_ids[0]) - new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0) - print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') - return diff --git a/modules/training.py b/modules/training.py index cdf7c59..3044690 100644 --- a/modules/training.py +++ b/modules/training.py @@ -1,26 +1,34 @@ +import os + +os.environ["WANDB_MODE"] = "offline" +# os.environ["WANDB_DISABLED"] = "true" + import json import math import random +import shutil import sys import threading import time import traceback +from datetime import datetime from pathlib import Path import gradio as gr import torch import transformers - -import shutil -from datetime import datetime - from datasets import Dataset, load_dataset from peft import ( LoraConfig, get_peft_model, - prepare_model_for_int8_training, + prepare_model_for_kbit_training, set_peft_model_state_dict ) +from peft.utils.other import \ + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +) from modules import shared, ui, utils from modules.evaluate import ( @@ -29,126 +37,130 @@ from modules.evaluate import ( save_past_evaluations ) from modules.logging_colors import logger +from modules.models import reload_model +from modules.utils import natural_keys -# This mapping is from a very recent commit, not yet released. -# If not available, default to a backup map for some common model types. -try: - from peft.utils.other import \ - TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \ - model_to_lora_modules - from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - ) - MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES} -except: - standard_modules = ["q_proj", "v_proj"] - model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"], "rw": ["query_key_value"]} - MODEL_CLASSES = { - "LlamaForCausalLM": "llama", - "OPTForCausalLM": "opt", - "GPTJForCausalLM": "gptj", - "GPTNeoXForCausalLM": "gpt_neox", - "RWForCausalLM": "rw" - - } +MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()} +PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"] +WANT_INTERRUPT = False train_log = {} train_template = {} -WANT_INTERRUPT = False -PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss"] - -def create_train_interface(): - with gr.Tab('Train LoRA', elem_id='lora-train-tab'): - gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)") - - with gr.Row(): - lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file') - always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).') - save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.') - - with gr.Row(): - copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=utils.get_available_loras()) - ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': utils.get_available_loras()}, 'refresh-button') - - with gr.Row(): - # TODO: Implement multi-device support. - micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.') - batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.') - - with gr.Row(): - epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.') - learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.') - lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.') - - # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale. - lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.') - lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.') - - cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.') - - with gr.Tab(label='Formatted Dataset'): +def create_ui(): + with gr.Tab("Training", elem_id="training-tab"): + with gr.Tab('Train LoRA', elem_id='lora-train-tab'): + tmp = gr.State('') with gr.Row(): - dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.') - ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button') - eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.') - ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button') - format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.') - ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button') + with gr.Column(): + gr.Markdown("[Tutorial](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)") - eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.') + with gr.Row(): + copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=utils.get_available_loras(), elem_classes=['slim-dropdown']) + ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': utils.get_available_loras()}, 'refresh-button') - with gr.Tab(label="Raw text file"): + with gr.Row(): + with gr.Column(scale=5): + lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file') + with gr.Column(): + always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background']) + + with gr.Row(): + with gr.Column(): + lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.') + lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.') + batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.') + micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.') + cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.') + + with gr.Column(): + save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.') + + epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.') + learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.') + lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown']) + + with gr.Accordion(label='Advanced Options', open=False): + with gr.Row(): + with gr.Column(): + lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.') + stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)') + optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown']) + + with gr.Column(): + warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.') + train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.') + + add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut") + + higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.') + report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True) + + with gr.Column(): + with gr.Tab(label='Formatted Dataset'): + with gr.Row(): + format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown']) + ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button') + + with gr.Row(): + dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown']) + ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button') + + with gr.Row(): + eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown']) + ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button') + + eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.') + + with gr.Tab(label="Raw text file"): + with gr.Row(): + raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.', elem_classes=['slim-dropdown']) + ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button') + + with gr.Row(): + with gr.Column(): + overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='How many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length). Setting overlap to exactly half the cutoff length may be ideal.') + newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.') + + with gr.Column(): + hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.') + min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number') + + with gr.Row(): + start_button = gr.Button("Start LoRA Training", variant='primary') + stop_button = gr.Button("Interrupt") + + output = gr.Markdown(value="Ready") + + with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'): with gr.Row(): - raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.') - ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button') - hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.') + with gr.Column(): + models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True) + evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.') + with gr.Row(): + with gr.Column(): + stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') + with gr.Column(): + max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.') + + with gr.Row(): + start_current_evaluation = gr.Button("Evaluate loaded model") + start_evaluation = gr.Button("Evaluate selected models") + stop_evaluation = gr.Button("Interrupt") + + with gr.Column(): + evaluation_log = gr.Markdown(value='') + + evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True) with gr.Row(): - overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.') - newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.') - - with gr.Accordion(label='Advanced Options', open=False): - lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.') - warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.') - optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.') - train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.') - stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)') - - with gr.Row(): - higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.') - - with gr.Row(): - start_button = gr.Button("Start LoRA Training") - stop_button = gr.Button("Interrupt") - - output = gr.Markdown(value="Ready") - - with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'): - with gr.Row(): - with gr.Column(): - models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True) - evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.') - with gr.Row(): - stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') - max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.') - - with gr.Row(): - start_current_evaluation = gr.Button("Evaluate loaded model") - start_evaluation = gr.Button("Evaluate selected models") - stop_evaluation = gr.Button("Interrupt") - - with gr.Column(): - evaluation_log = gr.Markdown(value='') - - evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True) - with gr.Row(): - save_comments = gr.Button('Save comments', elem_classes="small-button") - refresh_table = gr.Button('Refresh the table', elem_classes="small-button") + save_comments = gr.Button('Save comments', elem_classes="small-button") + refresh_table = gr.Button('Refresh the table', elem_classes="small-button") # Training events - all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss] + all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to] + copy_from.change(do_copy_params, [copy_from] + all_params, all_params) start_button.click(do_train, all_params, output) stop_button.click(do_interrupt, None, None, queue=False) @@ -160,7 +172,6 @@ def create_train_interface(): ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False) start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False) - tmp = gr.State('') start_current_evaluation.click(lambda: ['current model'], None, tmp) ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False) start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False) @@ -203,8 +214,6 @@ def change_rank_limit(use_higher_ranks: bool): def clean_path(base_path: str, path: str): """Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" - # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. - # Or swap it to a strict whitelist of [a-zA-Z_0-9] path = path.replace('\\', '/').replace('..', '_') if base_path is None: return path @@ -223,7 +232,7 @@ def backup_adapter(input_folder): creation_date_str = creation_date.strftime("Backup-%Y-%m-%d") # Create the new subfolder - subfolder_path = Path(f"{input_folder}/{creation_date_str}") + subfolder_path = Path(f"{input_folder}/{creation_date_str}") subfolder_path.mkdir(parents=True, exist_ok=True) # Check if the file already exists in the subfolder @@ -240,6 +249,7 @@ def backup_adapter(input_folder): except Exception as e: print("An error occurred in backup_adapter:", str(e)) + def calc_trainable_parameters(model): trainable_params = 0 all_param = 0 @@ -252,29 +262,29 @@ def calc_trainable_parameters(model): all_param += num_params if param.requires_grad: trainable_params += num_params - - return trainable_params,all_param + + return trainable_params, all_param -def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float): +def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str): if shared.args.monkey_patch: - from monkeypatch.peft_tuners_lora_monkey_patch import ( - replace_peft_model_with_gptq_lora_model + from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( + replace_peft_model_with_int4_lora_model ) - replace_peft_model_with_gptq_lora_model() + replace_peft_model_with_int4_lora_model() global WANT_INTERRUPT WANT_INTERRUPT = False # == Input validation / processing == - yield "Prepping..." + yield "Preparing the input..." lora_file_path = clean_path(None, lora_name) if lora_file_path.strip() == '': yield "Missing or invalid LoRA file name input." return - lora_file_path = f"{shared.args.lora_dir}/{lora_file_path}" + lora_file_path = f"{Path(shared.args.lora_dir)}/{lora_file_path}" actual_lr = float(learning_rate) model_type = type(shared.model).__name__ @@ -295,15 +305,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch time.sleep(5) - if shared.args.wbits > 0 and not shared.args.monkey_patch: - yield "LoRA training with GPTQ models requires loading with `--monkey-patch`" + if shared.args.loader == 'GPTQ-for-LLaMa' and not shared.args.monkey_patch: + yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`" return - elif not (shared.args.load_in_8bit or shared.args.load_in_4bit) and shared.args.wbits <= 0: - yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*" - logger.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.") - time.sleep(2) # Give it a moment for the message to show in UI before continuing - if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: yield "Cannot input zeroes." return @@ -314,14 +319,22 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch def encode(text, add_bos_token): result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len) + # Check if the first two tokens are BOS + if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]: + result = result[1:] + if not add_bos_token and result[0] == shared.tokenizer.bos_token_id: result = result[1:] return result - def tokenize(prompt): + def tokenize(prompt, append_eos_token=False): if train_only_after == '' or train_only_after not in prompt: input_ids = encode(prompt, True) + + if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len: + input_ids.append(shared.tokenizer.eos_token_id) + input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids labels = [1] * len(input_ids) @@ -330,6 +343,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch before_tokens = encode(prompt[:ind], True) after_tokens = encode(prompt[ind:], False) + if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id: + after_tokens.append(shared.tokenizer.eos_token_id) + full_length = len(after_tokens) + len(before_tokens) if full_length > cutoff_len: after_tokens = after_tokens[:cutoff_len - len(before_tokens)] @@ -350,31 +366,45 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch # == Prep the dataset, format, etc == if raw_text_file not in ['None', '']: - logger.info("Loading raw text file dataset...") - train_template["template_type"] = "raw_text" + logger.info("Loading raw text file dataset...") + fullpath = clean_path('training/datasets', f'{raw_text_file}') + fullpath = Path(fullpath) + if fullpath.is_dir(): + logger.info('Training path directory {}'.format(raw_text_file)) + raw_text = "" + file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name)) + for file_path in file_paths: + if file_path.is_file(): + with file_path.open('r', encoding='utf-8') as file: + raw_text += file.read().replace('\r', '') - with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: - raw_text = file.read().replace('\r', '') + logger.info(f"Loaded training file: {file_path.name}") + else: + with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: + raw_text = file.read().replace('\r', '') cut_string = hard_cut_string.replace('\\n', '\n') + eos_added = 0 out_tokens = [] for text_part in raw_text.split(cut_string): - if text_part.strip() == '': + if len(text_part.strip()) <= min_chars: continue tokens = shared.tokenizer.encode(text_part) + if add_eos_token: + tokens.append(shared.tokenizer.eos_token_id) + eos_added += 1 + step = cutoff_len - overlap_len if step <= 0: yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})" return - tokens = list(split_chunks(tokens, step)) - for i in range(1, len(tokens)): - tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i] + out_tokens.extend(split_chunks(tokens, cutoff_len, step)) - out_tokens.extend(tokens) - del tokens + if eos_added > 0: + print(f"EOS added to {eos_added} text blocks") del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM text_chunks = [shared.tokenizer.decode(x) for x in out_tokens] @@ -387,11 +417,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch eval_data = None else: if dataset in ['None', '']: - yield "**Missing dataset choice input, cannot continue.**" + yield "Missing dataset choice input, cannot continue." return if format in ['None', '']: - yield "**Missing format choice input, cannot continue.**" + yield "Missing format choice input, cannot continue." return train_template["template_type"] = "dataset" @@ -406,16 +436,16 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch def generate_prompt(data_point: dict[str, str]): for options, data in format_data.items(): - if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)): + if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)): for key, val in data_point.items(): - if val is not None: + if type(val) is str: data = data.replace(f'%{key}%', val) return data raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"') def generate_and_tokenize_prompt(data_point): prompt = generate_prompt(data_point) - return tokenize(prompt) + return tokenize(prompt, add_eos_token) logger.info("Loading JSON datasets...") data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json')) @@ -427,12 +457,33 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json')) eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30)) + # == We MUST reload model if it went through any previous training, even failed one == + if shared.model_dirty_from_training: + selected_model = shared.model_name + if selected_model: + print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m") + try: + yield f"Reloading {selected_model}..." + reload_model() + if shared.model is not None: + print("Model reloaded OK, continue with training.") + else: + return f"Failed to load {selected_model}." + except: + exc = traceback.format_exc() + logger.error('Failed to reload the model.') + print(exc) + return exc.replace('\n', '\n\n') + # == Start prepping the model itself == if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): logger.info("Getting model ready...") - prepare_model_for_int8_training(shared.model) + prepare_model_for_kbit_training(shared.model) - logger.info("Prepping for training...") + # base model is now frozen and should not be reused for any other LoRA training than this one + shared.model_dirty_from_training = True + + logger.info("Preparing for training...") config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, @@ -457,15 +508,16 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin") set_peft_model_state_dict(lora_model, state_dict_peft) except: - yield traceback.format_exc() + yield traceback.format_exc().replace('\n', '\n\n') return if shared.args.monkey_patch: - for n, m in lora_model.named_modules(): - if '4bit' in str(type(m)): + from alpaca_lora_4bit.autograd_4bit import Autograd4bitQuantLinear + from alpaca_lora_4bit.models import Linear4bitLt + for _, m in lora_model.named_modules(): + if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt): if m.is_v1_model: m.zeros = m.zeros.half() - m.scales = m.scales.half() class Tracked(): @@ -518,6 +570,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch train_dataset=train_data, eval_dataset=eval_data, args=transformers.TrainingArguments( + report_to=report_to if report_to != "None" else None, per_device_train_batch_size=micro_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps), @@ -534,7 +587,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch load_best_model_at_end=eval_data is not None, # TODO: Enable multi-device support ddp_find_unused_parameters=None, - no_cuda=shared.args.cpu + no_cuda=shared.args.cpu, ), data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False), callbacks=list([Callbacks()]) @@ -559,15 +612,19 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch yield "Starting..." lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) - - if lora_all_param>0: - print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})") + projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]]) + + print(f"Training '{model_id}' model using ({projections_string}) projections") + + if lora_all_param > 0: + print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})") train_log.update({"base_model_name": shared.model_name}) train_log.update({"base_model_class": shared.model.__class__.__name__}) train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)}) train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)}) + train_log.update({"projections": projections_string}) if stop_at_loss > 0: print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m") @@ -576,7 +633,26 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch yield "Interrupted before start." return + def log_train_dataset(trainer): + decoded_entries = [] + # Try to decode the entries and write the log file + try: + # Iterate over the first 10 elements in the dataset (or fewer if there are less than 10) + for i in range(min(10, len(trainer.train_dataset))): + decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids']) + decoded_entries.append({"value": decoded_text}) + + # Write the log file + Path('logs').mkdir(exist_ok=True) + with open(Path('logs/train_dataset_sample.json'), 'w') as json_file: + json.dump(decoded_entries, json_file, indent=4) + + logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.") + except Exception as e: + logger.error(f"Failed to create log file due to error: {e}") + def threaded_run(): + log_train_dataset(trainer) trainer.train() # Note: save in the thread in case the gradio thread breaks (eg browser closed) lora_model.save_pretrained(lora_file_path) @@ -619,15 +695,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch if WANT_INTERRUPT: logger.info("Training interrupted.") - yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`" + yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`." else: logger.info("Training complete!") - yield f"Done! LoRA saved to `{lora_file_path}`" + yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training." -def split_chunks(arr, step): +def split_chunks(arr, size, step): for i in range(0, len(arr), step): - yield arr[i:i + step] + yield arr[i:i + size] def cut_chunk_for_newline(chunk: str, max_length: int): diff --git a/modules/ui.py b/modules/ui.py index 3504953..0a19b23 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,27 +1,32 @@ -import json +import copy from pathlib import Path import gradio as gr import torch +import yaml from modules import shared -with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f: +with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r') as f: css = f.read() -with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f: - chat_css = f.read() -with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f: - main_js = f.read() -with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: - chat_js = f.read() +with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f: + css += f.read() +with open(Path(__file__).resolve().parent / '../js/main.js', 'r') as f: + js = f.read() +with open(Path(__file__).resolve().parent / '../js/save_files.js', 'r') as f: + save_files_js = f.read() +with open(Path(__file__).resolve().parent / '../js/switch_tabs.js', 'r') as f: + switch_tabs_js = f.read() +with open(Path(__file__).resolve().parent / '../js/show_controls.js', 'r') as f: + show_controls_js = f.read() refresh_symbol = '🔄' delete_symbol = '🗑️' save_symbol = '💾' theme = gr.themes.Default( - font=['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'], + font=['Noto Sans', 'Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'], font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'], ).set( border_color_primary='#c5c5d2', @@ -30,6 +35,11 @@ theme = gr.themes.Default( background_fill_secondary='#eaeaea' ) +if Path("notification.mp3").exists(): + audio_notification_js = "document.querySelector('#audio_notification audio')?.play();" +else: + audio_notification_js = "" + def list_model_elements(): elements = [ @@ -54,17 +64,23 @@ def list_model_elements(): 'no_inject_fused_attention', 'no_inject_fused_mlp', 'no_use_cuda_fp16', + 'disable_exllama', + 'cfg_cache', 'threads', 'n_batch', 'no_mmap', + 'low_vram', 'mlock', + 'mul_mat_q', 'n_gpu_layers', + 'tensor_split', 'n_ctx', 'llama_cpp_seed', 'gpu_split', 'max_seq_len', 'compress_pos_emb', - 'alpha_value' + 'alpha_value', + 'rope_freq_base' ] for i in range(torch.cuda.device_count()): @@ -76,6 +92,8 @@ def list_model_elements(): def list_interface_input_elements(): elements = [ 'max_new_tokens', + 'auto_max_new_tokens', + 'max_tokens_second', 'seed', 'temperature', 'top_p', @@ -96,8 +114,11 @@ def list_interface_input_elements(): 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'negative_prompt', + 'guidance_scale', 'add_bos_token', 'ban_eos_token', + 'custom_token_bans', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', @@ -106,31 +127,38 @@ def list_interface_input_elements(): 'top_a', ] - if shared.args.chat: - elements += [ - 'character_menu', - 'history', - 'name1', - 'name2', - 'greeting', - 'context', - 'chat_generation_attempts', - 'stop_at_newline', - 'mode', - 'instruction_template', - 'name1_instruct', - 'name2_instruct', - 'context_instruct', - 'turn_template', - 'chat_style', - 'chat-instruct_command', - ] - else: - elements.append('textbox') - if not shared.args.notebook: - elements.append('output_textbox') + # Chat elements + elements += [ + 'textbox', + 'start_with', + 'character_menu', + 'history', + 'name1', + 'name2', + 'greeting', + 'context', + 'mode', + 'instruction_template', + 'name1_instruct', + 'name2_instruct', + 'context_instruct', + 'turn_template', + 'chat_style', + 'chat-instruct_command', + ] + # Notebook/default elements + elements += [ + 'textbox-notebook', + 'textbox-default', + 'output_textbox', + 'prompt_menu-default', + 'prompt_menu-notebook', + ] + + # Model elements elements += list_model_elements() + return elements @@ -141,9 +169,6 @@ def gather_interface_values(*args): if not shared.args.multi_user: shared.persistent_interface_state = output - Path('logs').mkdir(exist_ok=True) - with open(Path(f'logs/session_{shared.get_mode()}_autosave.json'), 'w') as f: - f.write(json.dumps(output, indent=4)) return output @@ -159,8 +184,30 @@ def apply_interface_values(state, use_persistent=False): return [state[k] if k in state else gr.update() for k in elements] -class ToolButton(gr.Button, gr.components.FormComponent): - """Small button with single emoji as text, fits inside gradio forms""" +def save_settings(state, preset, instruction_template, extensions, show_controls): + output = copy.deepcopy(shared.settings) + exclude = ['name2', 'greeting', 'context', 'turn_template'] + for k in state: + if k in shared.settings and k not in exclude: + output[k] = state[k] + + output['preset'] = preset + output['prompt-default'] = state['prompt_menu-default'] + output['prompt-notebook'] = state['prompt_menu-notebook'] + output['character'] = state['character_menu'] + output['instruction_template'] = instruction_template + output['default_extensions'] = extensions + output['seed'] = int(output['seed']) + output['show_controls'] = show_controls + + return yaml.dump(output, sort_keys=False, width=float("inf")) + + +class ToolButton(gr.Button, gr.components.IOComponent): + """ + Small button with single emoji as text, fits inside gradio forms + Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui + """ def __init__(self, **kwargs): super().__init__(**kwargs) @@ -170,6 +217,9 @@ class ToolButton(gr.Button, gr.components.FormComponent): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class): + """ + Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui + """ def refresh(): refresh_method() args = refreshed_args() if callable(refreshed_args) else refreshed_args diff --git a/modules/ui_chat.py b/modules/ui_chat.py new file mode 100644 index 0000000..4ec93ec --- /dev/null +++ b/modules/ui_chat.py @@ -0,0 +1,349 @@ +import json +from functools import partial +from pathlib import Path + +import gradio as gr +from PIL import Image + +from modules import chat, prompts, shared, ui, utils +from modules.html_generator import chat_html_wrapper +from modules.text_generation import stop_everything_event +from modules.utils import gradio + +inputs = ('Chat input', 'interface_state') +reload_arr = ('history', 'name1', 'name2', 'mode', 'chat_style') +clear_arr = ('delete_chat-confirm', 'delete_chat', 'delete_chat-cancel') + + +def create_ui(): + shared.gradio['Chat input'] = gr.State() + shared.gradio['dummy'] = gr.State() + shared.gradio['history'] = gr.State({'internal': [], 'visible': []}) + + with gr.Tab('Chat', elem_id='chat-tab', elem_classes=("old-ui" if shared.args.chat_buttons else None)): + with gr.Row(): + with gr.Column(elem_id='chat-col'): + shared.gradio['display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': []}, '', '', 'chat', 'cai-chat')) + + with gr.Row(elem_id="chat-input-row"): + with gr.Column(scale=1, elem_id='gr-hover-container'): + gr.HTML(value='
      ', elem_id='gr-hover') + + with gr.Column(scale=10, elem_id='chat-input-container'): + shared.gradio['textbox'] = gr.Textbox(label='', placeholder='Send a message', elem_id='chat-input', elem_classes=['add_scrollbar']) + shared.gradio['show_controls'] = gr.Checkbox(value=shared.settings['show_controls'], label='Show controls (Ctrl+S)', elem_id='show-controls') + shared.gradio['typing-dots'] = gr.HTML(value='
      ', label='typing', elem_id='typing-container') + + with gr.Column(scale=1, elem_id='generate-stop-container'): + with gr.Row(): + shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop', visible=False) + shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary') + + # Hover menu buttons + with gr.Column(elem_id='chat-buttons'): + with gr.Row(): + shared.gradio['Regenerate'] = gr.Button('Regenerate (Ctrl + Enter)', elem_id='Regenerate') + shared.gradio['Continue'] = gr.Button('Continue (Alt + Enter)', elem_id='Continue') + shared.gradio['Remove last'] = gr.Button('Remove last reply (Ctrl + Shift + Backspace)', elem_id='Remove-last') + + with gr.Row(): + shared.gradio['Replace last reply'] = gr.Button('Replace last reply (Ctrl + Shift + L)', elem_id='Replace-last') + shared.gradio['Copy last reply'] = gr.Button('Copy last reply (Ctrl + Shift + K)', elem_id='Copy-last') + shared.gradio['Impersonate'] = gr.Button('Impersonate (Ctrl + Shift + M)', elem_id='Impersonate') + + with gr.Row(): + shared.gradio['Send dummy message'] = gr.Button('Send dummy message') + shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply') + + with gr.Row(): + shared.gradio['Start new chat'] = gr.Button('Start new chat') + + with gr.Row(): + shared.gradio['send-chat-to-default'] = gr.Button('Send to default') + shared.gradio['send-chat-to-notebook'] = gr.Button('Send to notebook') + + with gr.Row(elem_id='past-chats-row'): + shared.gradio['unique_id'] = gr.Dropdown(label='Past chats', elem_classes=['slim-dropdown']) + shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes='refresh-button') + shared.gradio['delete_chat'] = gr.Button('🗑️', elem_classes='refresh-button') + shared.gradio['delete_chat-cancel'] = gr.Button('Cancel', visible=False, elem_classes='refresh-button') + shared.gradio['delete_chat-confirm'] = gr.Button('Confirm', variant='stop', visible=False, elem_classes='refresh-button') + + with gr.Row(elem_id='rename-row'): + shared.gradio['rename_to'] = gr.Textbox(label='Rename to:', placeholder='New name', visible=False, elem_classes=['no-background']) + shared.gradio['rename_to-cancel'] = gr.Button('Cancel', visible=False, elem_classes='refresh-button') + shared.gradio['rename_to-confirm'] = gr.Button('Confirm', visible=False, elem_classes='refresh-button') + + with gr.Row(): + shared.gradio['start_with'] = gr.Textbox(label='Start reply with', placeholder='Sure thing!', value=shared.settings['start_with']) + + with gr.Row(): + shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'], value='chat', label='Mode', info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under Parameters > Instruction template must match the current model.', elem_id='chat-mode') + shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct') + + +def create_chat_settings_ui(): + with gr.Tab('Character'): + with gr.Row(): + with gr.Column(scale=8): + with gr.Row(): + shared.gradio['character_menu'] = gr.Dropdown(value='', choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button') + shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button') + shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button') + + shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') + shared.gradio['name2'] = gr.Textbox(value='', lines=1, label='Character\'s name') + shared.gradio['context'] = gr.Textbox(value='', lines=10, label='Context', elem_classes=['add_scrollbar']) + shared.gradio['greeting'] = gr.Textbox(value='', lines=5, label='Greeting', elem_classes=['add_scrollbar']) + + with gr.Column(scale=1): + shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil') + shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None) + + with gr.Tab('Instruction template'): + with gr.Row(): + with gr.Row(): + shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button') + shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button') + shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button') + + shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string') + shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string') + shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context', elem_classes=['add_scrollbar']) + shared.gradio['turn_template'] = gr.Textbox(value='', lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.', elem_classes=['add_scrollbar']) + with gr.Row(): + shared.gradio['send_instruction_to_default'] = gr.Button('Send to default', elem_classes=['small-button']) + shared.gradio['send_instruction_to_notebook'] = gr.Button('Send to notebook', elem_classes=['small-button']) + shared.gradio['send_instruction_to_negative_prompt'] = gr.Button('Send to negative prompt', elem_classes=['small-button']) + + with gr.Row(): + shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.', elem_classes=['add_scrollbar']) + + with gr.Tab('Chat history'): + with gr.Row(): + with gr.Column(): + shared.gradio['save_chat_history'] = gr.Button(value='Save history') + + with gr.Column(): + shared.gradio['load_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'], label='Upload History JSON') + + with gr.Tab('Upload character'): + with gr.Tab('YAML or JSON'): + with gr.Row(): + shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json', '.yaml'], label='JSON or YAML File') + shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)') + + shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False) + + with gr.Tab('TavernAI PNG'): + with gr.Row(): + with gr.Column(): + shared.gradio['upload_img_tavern'] = gr.Image(type='pil', label='TavernAI PNG File', elem_id='upload_img_tavern') + shared.gradio['tavern_json'] = gr.State() + with gr.Column(): + shared.gradio['tavern_name'] = gr.Textbox(value='', lines=1, label='Name', interactive=False) + shared.gradio['tavern_desc'] = gr.Textbox(value='', lines=4, max_lines=4, label='Description', interactive=False) + + shared.gradio['Submit tavern character'] = gr.Button(value='Submit', interactive=False) + + +def create_event_handlers(): + + # Obsolete variables, kept for compatibility with old extensions + shared.input_params = gradio(inputs) + shared.reload_inputs = gradio(reload_arr) + + shared.gradio['Generate'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then( + chat.generate_chat_reply_wrapper, gradio(inputs), gradio('display', 'history'), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['textbox'].submit( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then( + chat.generate_chat_reply_wrapper, gradio(inputs), gradio('display', 'history'), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['Regenerate'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + partial(chat.generate_chat_reply_wrapper, regenerate=True), gradio(inputs), gradio('display', 'history'), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['Continue'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + partial(chat.generate_chat_reply_wrapper, _continue=True), gradio(inputs), gradio('display', 'history'), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['Impersonate'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then( + chat.impersonate_wrapper, gradio(inputs), gradio('textbox', 'display'), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['Replace last reply'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.replace_last_reply, gradio('textbox', 'interface_state'), gradio('history')).then( + lambda: '', None, gradio('textbox'), show_progress=False).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None) + + shared.gradio['Send dummy message'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.send_dummy_message, gradio('textbox', 'interface_state'), gradio('history')).then( + lambda: '', None, gradio('textbox'), show_progress=False).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None) + + shared.gradio['Send dummy reply'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.send_dummy_reply, gradio('textbox', 'interface_state'), gradio('history')).then( + lambda: '', None, gradio('textbox'), show_progress=False).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None) + + shared.gradio['Remove last'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.remove_last_message, gradio('history'), gradio('textbox', 'history'), show_progress=False).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None) + + shared.gradio['Stop'].click( + stop_everything_event, None, None, queue=False).then( + chat.redraw_html, gradio(reload_arr), gradio('display')) + + if not shared.args.multi_user: + shared.gradio['unique_id'].select( + chat.load_history, gradio('unique_id', 'character_menu', 'mode'), gradio('history')).then( + chat.redraw_html, gradio(reload_arr), gradio('display')) + + shared.gradio['Start new chat'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.start_new_chat, gradio('interface_state'), gradio('history')).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + lambda x: gr.update(choices=(histories := chat.find_all_histories(x)), value=histories[0]), gradio('interface_state'), gradio('unique_id')) + + shared.gradio['delete_chat'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, gradio(clear_arr)) + shared.gradio['delete_chat-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, gradio(clear_arr)) + shared.gradio['delete_chat-confirm'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.delete_history, gradio('unique_id', 'character_menu', 'mode'), None).then( + chat.load_latest_history, gradio('interface_state'), gradio('history')).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + lambda x: gr.update(choices=(histories := chat.find_all_histories(x)), value=histories[0]), gradio('interface_state'), gradio('unique_id')).then( + lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, gradio(clear_arr)) + + shared.gradio['rename_chat'].click( + lambda x: x, gradio('unique_id'), gradio('rename_to')).then( + lambda: [gr.update(visible=True)] * 3, None, gradio('rename_to', 'rename_to-confirm', 'rename_to-cancel'), show_progress=False) + + shared.gradio['rename_to-cancel'].click( + lambda: [gr.update(visible=False)] * 3, None, gradio('rename_to', 'rename_to-confirm', 'rename_to-cancel'), show_progress=False) + + shared.gradio['rename_to-confirm'].click( + chat.rename_history, gradio('unique_id', 'rename_to', 'character_menu', 'mode'), None).then( + lambda: [gr.update(visible=False)] * 3, None, gradio('rename_to', 'rename_to-confirm', 'rename_to-cancel'), show_progress=False).then( + lambda x, y: gr.update(choices=chat.find_all_histories(x), value=y), gradio('interface_state', 'rename_to'), gradio('unique_id')) + + shared.gradio['load_chat_history'].upload( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.start_new_chat, gradio('interface_state'), gradio('history')).then( + chat.load_history_json, gradio('load_chat_history', 'history'), gradio('history')).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + lambda x: gr.update(choices=(histories := chat.find_all_histories(x)), value=histories[0]), gradio('interface_state'), gradio('unique_id')).then( + chat.save_history, gradio('history', 'unique_id', 'character_menu', 'mode'), None).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_chat()}}') + + shared.gradio['character_menu'].change( + partial(chat.load_character, instruct=False), gradio('character_menu', 'name1', 'name2'), gradio('name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy')).success( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.load_latest_history, gradio('interface_state'), gradio('history')).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + lambda x: gr.update(choices=(histories := chat.find_all_histories(x)), value=histories[0]), gradio('interface_state'), gradio('unique_id')) + + shared.gradio['mode'].change( + lambda x: gr.update(visible=x != 'instruct'), gradio('mode'), gradio('chat_style'), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + partial(chat.character_is_loaded, raise_exception=True), gradio('interface_state'), None).success( + chat.load_latest_history, gradio('interface_state'), gradio('history')).then( + chat.redraw_html, gradio(reload_arr), gradio('display')).then( + lambda x: gr.update(choices=(histories := chat.find_all_histories(x)), value=histories[0]), gradio('interface_state'), gradio('unique_id')) + + shared.gradio['chat_style'].change(chat.redraw_html, gradio(reload_arr), gradio('display')) + shared.gradio['instruction_template'].change( + partial(chat.load_character, instruct=True), gradio('instruction_template', 'name1_instruct', 'name2_instruct'), gradio('name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template')) + + shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, gradio('history'), gradio('textbox'), show_progress=False) + + # Save/delete a character + shared.gradio['save_character'].click( + lambda x: x, gradio('name2'), gradio('save_character_filename')).then( + lambda: gr.update(visible=True), None, gradio('character_saver')) + + shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter')) + + shared.gradio['save_template'].click( + lambda: 'My Template.yaml', None, gradio('save_filename')).then( + lambda: 'instruction-templates/', None, gradio('save_root')).then( + chat.generate_instruction_template_yaml, gradio('name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template'), gradio('save_contents')).then( + lambda: gr.update(visible=True), None, gradio('file_saver')) + + shared.gradio['delete_template'].click( + lambda x: f'{x}.yaml', gradio('instruction_template'), gradio('delete_filename')).then( + lambda: 'instruction-templates/', None, gradio('delete_root')).then( + lambda: gr.update(visible=True), None, gradio('file_deleter')) + + shared.gradio['save_chat_history'].click( + lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then( + None, gradio('temporary_text', 'character_menu', 'mode'), None, _js=f'(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}') + + shared.gradio['Submit character'].click( + chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_character()}}') + + shared.gradio['Submit tavern character'].click( + chat.upload_tavern_character, gradio('upload_img_tavern', 'tavern_json'), gradio('character_menu')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_character()}}') + + shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character')) + shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character')) + shared.gradio['upload_img_tavern'].upload(chat.check_tavern_character, gradio('upload_img_tavern'), gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False) + shared.gradio['upload_img_tavern'].clear(lambda: (None, None, None, gr.update(interactive=False)), None, gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False) + shared.gradio['your_picture'].change( + chat.upload_your_profile_picture, gradio('your_picture'), None).then( + partial(chat.redraw_html, reset_cache=True), gradio(reload_arr), gradio('display')) + + shared.gradio['send_instruction_to_default'].click( + prompts.load_instruction_prompt_simple, gradio('instruction_template'), gradio('textbox-default')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_default()}}') + + shared.gradio['send_instruction_to_notebook'].click( + prompts.load_instruction_prompt_simple, gradio('instruction_template'), gradio('textbox-notebook')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_notebook()}}') + + shared.gradio['send_instruction_to_negative_prompt'].click( + prompts.load_instruction_prompt_simple, gradio('instruction_template'), gradio('negative_prompt')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_generation_parameters()}}') + + shared.gradio['send-chat-to-default'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + partial(chat.generate_chat_prompt, '', _continue=True), gradio('interface_state'), gradio('textbox-default')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_default()}}') + + shared.gradio['send-chat-to-notebook'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + partial(chat.generate_chat_prompt, '', _continue=True), gradio('interface_state'), gradio('textbox-notebook')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_notebook()}}') + + shared.gradio['show_controls'].change(None, gradio('show_controls'), None, _js=f'(x) => {{{ui.show_controls_js}; toggle_controls(x)}}') diff --git a/modules/ui_default.py b/modules/ui_default.py new file mode 100644 index 0000000..7357094 --- /dev/null +++ b/modules/ui_default.py @@ -0,0 +1,103 @@ +import gradio as gr + +from modules import logits, shared, ui, utils +from modules.prompts import count_tokens, load_prompt +from modules.text_generation import ( + generate_reply_wrapper, + get_token_ids, + stop_everything_event +) +from modules.utils import gradio + +inputs = ('textbox-default', 'interface_state') +outputs = ('output_textbox', 'html-default') + + +def create_ui(): + with gr.Tab('Default', elem_id='default-tab'): + shared.gradio['last_input-default'] = gr.State('') + with gr.Row(): + with gr.Column(): + with gr.Row(): + shared.gradio['textbox-default'] = gr.Textbox(value='', lines=27, label='Input', elem_classes=['textbox_default', 'add_scrollbar']) + shared.gradio['token-counter-default'] = gr.HTML(value="0", elem_classes=["token-counter", "default-token-counter"]) + + with gr.Row(): + shared.gradio['Generate-default'] = gr.Button('Generate', variant='primary') + shared.gradio['Stop-default'] = gr.Button('Stop', elem_id='stop') + shared.gradio['Continue-default'] = gr.Button('Continue') + + with gr.Row(): + shared.gradio['prompt_menu-default'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['prompt_menu-default'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button') + shared.gradio['save_prompt-default'] = gr.Button('💾', elem_classes='refresh-button') + shared.gradio['delete_prompt-default'] = gr.Button('🗑️', elem_classes='refresh-button') + + with gr.Column(): + with gr.Tab('Raw'): + shared.gradio['output_textbox'] = gr.Textbox(lines=27, label='Output', elem_id='textbox-default', elem_classes=['textbox_default_output', 'add_scrollbar']) + + with gr.Tab('Markdown'): + shared.gradio['markdown_render-default'] = gr.Button('Render') + shared.gradio['markdown-default'] = gr.Markdown() + + with gr.Tab('HTML'): + shared.gradio['html-default'] = gr.HTML() + + with gr.Tab('Logits'): + with gr.Row(): + with gr.Column(scale=10): + shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities') + with gr.Column(scale=1): + shared.gradio['use_samplers-default'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background']) + + with gr.Row(): + shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar']) + shared.gradio['logits-default-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar']) + + with gr.Tab('Tokens'): + shared.gradio['get_tokens-default'] = gr.Button('Get token IDs for the input') + shared.gradio['tokens-default'] = gr.Textbox(lines=23, label='Tokens', elem_classes=['textbox_logits', 'add_scrollbar', 'monospace']) + + +def create_event_handlers(): + shared.gradio['Generate-default'].click( + lambda x: x, gradio('textbox-default'), gradio('last_input-default')).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate_reply_wrapper, gradio(inputs), gradio(outputs), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['textbox-default'].submit( + lambda x: x, gradio('textbox-default'), gradio('last_input-default')).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate_reply_wrapper, gradio(inputs), gradio(outputs), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['markdown_render-default'].click(lambda x: x, gradio('output_textbox'), gradio('markdown-default'), queue=False) + shared.gradio['Continue-default'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate_reply_wrapper, [shared.gradio['output_textbox']] + gradio(inputs)[1:], gradio(outputs), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['Stop-default'].click(stop_everything_event, None, None, queue=False) + shared.gradio['prompt_menu-default'].change(load_prompt, gradio('prompt_menu-default'), gradio('textbox-default'), show_progress=False) + shared.gradio['save_prompt-default'].click( + lambda x: x, gradio('textbox-default'), gradio('save_contents')).then( + lambda: 'prompts/', None, gradio('save_root')).then( + lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then( + lambda: gr.update(visible=True), None, gradio('file_saver')) + + shared.gradio['delete_prompt-default'].click( + lambda: 'prompts/', None, gradio('delete_root')).then( + lambda x: x + '.txt', gradio('prompt_menu-default'), gradio('delete_filename')).then( + lambda: gr.update(visible=True), None, gradio('file_deleter')) + + shared.gradio['textbox-default'].change(lambda x: f"{count_tokens(x)}", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False) + shared.gradio['get_logits-default'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + logits.get_next_logits, gradio('textbox-default', 'interface_state', 'use_samplers-default', 'logits-default'), gradio('logits-default', 'logits-default-previous'), show_progress=False) + + shared.gradio['get_tokens-default'].click(get_token_ids, gradio('textbox-default'), gradio('tokens-default'), show_progress=False) diff --git a/modules/ui_file_saving.py b/modules/ui_file_saving.py new file mode 100644 index 0000000..d80378a --- /dev/null +++ b/modules/ui_file_saving.py @@ -0,0 +1,75 @@ +import gradio as gr + +from modules import chat, presets, shared, ui, utils +from modules.utils import gradio + + +def create_ui(): + + # Text file saver + with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']: + shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name') + shared.gradio['save_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False) + shared.gradio['save_contents'] = gr.Textbox(lines=10, label='File contents') + with gr.Row(): + shared.gradio['save_confirm'] = gr.Button('Save', elem_classes="small-button") + shared.gradio['save_cancel'] = gr.Button('Cancel', elem_classes="small-button") + + # Text file deleter + with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_deleter']: + shared.gradio['delete_filename'] = gr.Textbox(lines=1, label='File name') + shared.gradio['delete_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False) + with gr.Row(): + shared.gradio['delete_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop') + shared.gradio['delete_cancel'] = gr.Button('Cancel', elem_classes="small-button") + + # Character saver/deleter + with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']: + shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info='The character will be saved to your characters/ folder with this base filename.') + with gr.Row(): + shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button") + shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button") + + with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_deleter']: + gr.Markdown('Confirm the character deletion?') + with gr.Row(): + shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop') + shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button") + + +def create_event_handlers(): + shared.gradio['save_confirm'].click( + lambda x, y, z: utils.save_file(x + y, z), gradio('save_root', 'save_filename', 'save_contents'), None).then( + lambda: gr.update(visible=False), None, gradio('file_saver')) + + shared.gradio['delete_confirm'].click( + lambda x, y: utils.delete_file(x + y), gradio('delete_root', 'delete_filename'), None).then( + lambda: gr.update(visible=False), None, gradio('file_deleter')) + + shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_deleter')) + shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_saver')) + + shared.gradio['save_character_confirm'].click( + chat.save_character, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), None).then( + lambda: gr.update(visible=False), None, gradio('character_saver')).then( + lambda x: gr.update(choices=utils.get_available_characters(), value=x), gradio('save_character_filename'), gradio('character_menu')) + + shared.gradio['delete_character_confirm'].click( + chat.delete_character, gradio('character_menu'), None).then( + lambda: gr.update(visible=False), None, gradio('character_deleter')).then( + lambda: gr.update(choices=(characters := utils.get_available_characters()), value=characters[0]), None, gradio('character_menu')) + + shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver')) + shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter')) + + shared.gradio['save_preset'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + presets.generate_preset_yaml, gradio('interface_state'), gradio('save_contents')).then( + lambda: 'presets/', None, gradio('save_root')).then( + lambda: 'My Preset.yaml', None, gradio('save_filename')).then( + lambda: gr.update(visible=True), None, gradio('file_saver')) + + shared.gradio['delete_preset'].click( + lambda x: f'{x}.yaml', gradio('preset_menu'), gradio('delete_filename')).then( + lambda: 'presets/', None, gradio('delete_root')).then( + lambda: gr.update(visible=True), None, gradio('file_deleter')) diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py new file mode 100644 index 0000000..f965d80 --- /dev/null +++ b/modules/ui_model_menu.py @@ -0,0 +1,259 @@ +import importlib +import math +import re +import traceback +from functools import partial +from pathlib import Path + +import gradio as gr +import psutil +import torch + +from modules import loaders, shared, ui, utils +from modules.logging_colors import logger +from modules.LoRA import add_lora_to_model +from modules.models import load_model, unload_model +from modules.models_settings import ( + apply_model_settings_to_state, + get_model_metadata, + save_model_settings, + update_model_parameters +) +from modules.utils import gradio + + +def create_ui(): + # Finding the default values for the GPU and CPU memories + total_mem = [] + for i in range(torch.cuda.device_count()): + total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024))) + + default_gpu_mem = [] + if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0: + for i in shared.args.gpu_memory: + if 'mib' in i.lower(): + default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i))) + else: + default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000) + + while len(default_gpu_mem) < len(total_mem): + default_gpu_mem.append(0) + + total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024)) + if shared.args.cpu_memory is not None: + default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory) + else: + default_cpu_mem = 0 + + with gr.Tab("Model", elem_id="model-tab"): + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + with gr.Row(): + shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=shared.model_name, label='Model', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button') + shared.gradio['load_model'] = gr.Button("Load", visible=not shared.settings['autoload_model'], elem_classes='refresh-button') + shared.gradio['unload_model'] = gr.Button("Unload", elem_classes='refresh-button') + shared.gradio['reload_model'] = gr.Button("Reload", elem_classes='refresh-button') + shared.gradio['save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button') + + with gr.Column(): + with gr.Row(): + shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button') + shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button') + + with gr.Row(): + with gr.Column(): + shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys(), value=None) + with gr.Box(): + with gr.Row(): + with gr.Column(): + for i in range(len(total_mem)): + shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i]) + + shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem) + shared.gradio['transformers_info'] = gr.Markdown('load-in-4bit params:') + shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype) + shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type) + + shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers) + shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx) + shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads) + shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch) + + shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None") + shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None") + shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None"], value=shared.args.model_type or "None") + shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0) + shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.') + shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7') + shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=0, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len) + shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=8, step=0.1, info='Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value) + shared.gradio['rope_freq_base'] = gr.Slider(label='rope_freq_base', minimum=0, maximum=1000000, step=1000, info='If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63)', value=shared.args.rope_freq_base) + shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.', value=shared.args.compress_pos_emb) + + with gr.Column(): + shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton) + shared.gradio['no_inject_fused_attention'] = gr.Checkbox(label="no_inject_fused_attention", value=shared.args.no_inject_fused_attention, info='Disable fused attention. Fused attention improves inference performance but uses more VRAM. Disable if running low on VRAM.') + shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp", value=shared.args.no_inject_fused_mlp, info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.') + shared.gradio['no_use_cuda_fp16'] = gr.Checkbox(label="no_use_cuda_fp16", value=shared.args.no_use_cuda_fp16, info='This can make models faster on some systems.') + shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.') + shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel, which can improve inference speed on some systems.') + shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu) + shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit) + shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16) + shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices) + shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk) + shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit) + shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant) + shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap) + shared.gradio['low_vram'] = gr.Checkbox(label="low-vram", value=shared.args.low_vram) + shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock) + shared.gradio['mul_mat_q'] = gr.Checkbox(label="mul_mat_q", value=shared.args.mul_mat_q, info='Recommended in most cases. Improves generation speed by 10-20%.') + shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Create an additional cache for CFG negative prompts.') + shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='Split the model across multiple GPUs, comma-separated list of proportions, e.g. 18,17') + shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed) + shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.') + shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa support is currently only kept for compatibility with older GPUs. AutoGPTQ or ExLlama is preferred when compatible. GPTQ-for-LLaMa is installed by default with the webui on supported systems. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).') + shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).') + shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.') + shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF loads llama.cpp as a Transformers model. To use it, you need to download a tokenizer.\n\nOption 1: download `oobabooga/llama-tokenizer` under "Download model or LoRA". That\'s a default Llama tokenizer.\n\nOption 2: place your .gguf in a subfolder of models/ along with these 3 files: tokenizer.model, tokenizer_config.json, and special_tokens_map.json. This takes precedence over Option 1.') + + with gr.Column(): + with gr.Row(): + shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.') + + shared.gradio['custom_model_menu'] = gr.Textbox(label="Download model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main. To download a single file, enter its name in the second box.") + shared.gradio['download_specific_file'] = gr.Textbox(placeholder="File name (for GGUF models)", show_label=False, max_lines=1) + with gr.Row(): + shared.gradio['download_model_button'] = gr.Button("Download", variant='primary') + shared.gradio['get_file_list'] = gr.Button("Get file list") + + with gr.Row(): + shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready') + + +def create_event_handlers(): + shared.gradio['loader'].change( + loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params())).then( + lambda value: gr.update(choices=loaders.get_model_types(value)), gradio('loader'), gradio('model_type')) + + # In this event handler, the interface state is read and updated + # with the model defaults (if any), and then the model is loaded + # unless "autoload_model" is unchecked + shared.gradio['model_menu'].change( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then( + ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then( + update_model_parameters, gradio('interface_state'), None).then( + load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False).success( + update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')) + + shared.gradio['load_model'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + update_model_parameters, gradio('interface_state'), None).then( + partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success( + update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')) + + shared.gradio['unload_model'].click( + unload_model, None, None).then( + lambda: "Model unloaded", None, gradio('model_status')) + + shared.gradio['reload_model'].click( + unload_model, None, None).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + update_model_parameters, gradio('interface_state'), None).then( + partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success( + update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')) + + shared.gradio['save_model_settings'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False) + + shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False) + shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True) + shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True) + shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model')) + + +def load_model_wrapper(selected_model, loader, autoload=False): + if not autoload: + yield f"The settings for `{selected_model}` have been updated.\n\nClick on \"Load\" to load it." + return + + if selected_model == 'None': + yield "No model selected" + else: + try: + yield f"Loading `{selected_model}`..." + shared.model_name = selected_model + unload_model() + if selected_model != '': + shared.model, shared.tokenizer = load_model(shared.model_name, loader) + + if shared.model is not None: + output = f"Successfully loaded `{selected_model}`." + + settings = get_model_metadata(selected_model) + if 'instruction_template' in settings: + output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template']) + + yield output + else: + yield f"Failed to load `{selected_model}`." + except: + exc = traceback.format_exc() + logger.error('Failed to load the model.') + print(exc) + yield exc.replace('\n', '\n\n') + + +def load_lora_wrapper(selected_loras): + yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras))) + add_lora_to_model(selected_loras) + yield ("Successfuly applied the LoRAs") + + +def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False): + try: + downloader_module = importlib.import_module("download-model") + downloader = downloader_module.ModelDownloader() + + progress(0.0) + yield ("Cleaning up the model/branch names") + model, branch = downloader.sanitize_model_and_branch_names(repo_id, None) + + yield ("Getting the download links from Hugging Face") + links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file) + + if return_links: + yield '\n\n'.join([f"`{Path(link).name}`" for link in links]) + return + + yield ("Getting the output folder") + base_folder = shared.args.lora_dir if is_lora else shared.args.model_dir + output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=base_folder) + + if check: + progress(0.5) + yield ("Checking previously downloaded files") + downloader.check_model_files(model, branch, links, sha256, output_folder) + progress(1.0) + else: + yield (f"Downloading file{'s' if len(links) > 1 else ''} to `{output_folder}/`") + downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1, is_llamacpp=is_llamacpp) + yield ("Done!") + except: + progress(1.0) + yield traceback.format_exc().replace('\n', '\n\n') + + +def update_truncation_length(current_length, state): + if state['loader'] in ['ExLlama', 'ExLlama_HF']: + return state['max_seq_len'] + elif state['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']: + return state['n_ctx'] + else: + return current_length diff --git a/modules/ui_notebook.py b/modules/ui_notebook.py new file mode 100644 index 0000000..60e3ee4 --- /dev/null +++ b/modules/ui_notebook.py @@ -0,0 +1,105 @@ +import gradio as gr + +from modules import logits, shared, ui, utils +from modules.prompts import count_tokens, load_prompt +from modules.text_generation import ( + generate_reply_wrapper, + get_token_ids, + stop_everything_event +) +from modules.utils import gradio + +inputs = ('textbox-notebook', 'interface_state') +outputs = ('textbox-notebook', 'html-notebook') + + +def create_ui(): + with gr.Tab('Notebook', elem_id='notebook-tab'): + shared.gradio['last_input-notebook'] = gr.State('') + with gr.Row(): + with gr.Column(scale=4): + with gr.Tab('Raw'): + with gr.Row(): + shared.gradio['textbox-notebook'] = gr.Textbox(value='', lines=27, elem_id='textbox-notebook', elem_classes=['textbox', 'add_scrollbar']) + shared.gradio['token-counter-notebook'] = gr.HTML(value="0", elem_classes=["token-counter"]) + + with gr.Tab('Markdown'): + shared.gradio['markdown_render-notebook'] = gr.Button('Render') + shared.gradio['markdown-notebook'] = gr.Markdown() + + with gr.Tab('HTML'): + shared.gradio['html-notebook'] = gr.HTML() + + with gr.Tab('Logits'): + with gr.Row(): + with gr.Column(scale=10): + shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities') + with gr.Column(scale=1): + shared.gradio['use_samplers-notebook'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background']) + + with gr.Row(): + shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits_notebook', 'add_scrollbar']) + shared.gradio['logits-notebook-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits_notebook', 'add_scrollbar']) + + with gr.Tab('Tokens'): + shared.gradio['get_tokens-notebook'] = gr.Button('Get token IDs for the input') + shared.gradio['tokens-notebook'] = gr.Textbox(lines=23, label='Tokens', elem_classes=['textbox_logits_notebook', 'add_scrollbar', 'monospace']) + + with gr.Row(): + shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button') + shared.gradio['Stop-notebook'] = gr.Button('Stop', elem_classes='small-button', elem_id='stop') + shared.gradio['Undo'] = gr.Button('Undo', elem_classes='small-button') + shared.gradio['Regenerate-notebook'] = gr.Button('Regenerate', elem_classes='small-button') + + with gr.Column(scale=1): + gr.HTML('
      ') + with gr.Row(): + shared.gradio['prompt_menu-notebook'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['prompt_menu-notebook'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, ['refresh-button', 'refresh-button-small']) + shared.gradio['save_prompt-notebook'] = gr.Button('💾', elem_classes=['refresh-button', 'refresh-button-small']) + shared.gradio['delete_prompt-notebook'] = gr.Button('🗑️', elem_classes=['refresh-button', 'refresh-button-small']) + + +def create_event_handlers(): + shared.gradio['Generate-notebook'].click( + lambda x: x, gradio('textbox-notebook'), gradio('last_input-notebook')).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate_reply_wrapper, gradio(inputs), gradio(outputs), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['textbox-notebook'].submit( + lambda x: x, gradio('textbox-notebook'), gradio('last_input-notebook')).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate_reply_wrapper, gradio(inputs), gradio(outputs), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['Undo'].click(lambda x: x, gradio('last_input-notebook'), gradio('textbox-notebook'), show_progress=False) + shared.gradio['markdown_render-notebook'].click(lambda x: x, gradio('textbox-notebook'), gradio('markdown-notebook'), queue=False) + shared.gradio['Regenerate-notebook'].click( + lambda x: x, gradio('last_input-notebook'), gradio('textbox-notebook'), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate_reply_wrapper, gradio(inputs), gradio(outputs), show_progress=False).then( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') + + shared.gradio['Stop-notebook'].click(stop_everything_event, None, None, queue=False) + shared.gradio['prompt_menu-notebook'].change(load_prompt, gradio('prompt_menu-notebook'), gradio('textbox-notebook'), show_progress=False) + shared.gradio['save_prompt-notebook'].click( + lambda x: x, gradio('textbox-notebook'), gradio('save_contents')).then( + lambda: 'prompts/', None, gradio('save_root')).then( + lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then( + lambda: gr.update(visible=True), None, gradio('file_saver')) + + shared.gradio['delete_prompt-notebook'].click( + lambda: 'prompts/', None, gradio('delete_root')).then( + lambda x: x + '.txt', gradio('prompt_menu-notebook'), gradio('delete_filename')).then( + lambda: gr.update(visible=True), None, gradio('file_deleter')) + + shared.gradio['textbox-notebook'].input(lambda x: f"{count_tokens(x)}", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False) + shared.gradio['get_logits-notebook'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + logits.get_next_logits, gradio('textbox-notebook', 'interface_state', 'use_samplers-notebook', 'logits-notebook'), gradio('logits-notebook', 'logits-notebook-previous'), show_progress=False) + + shared.gradio['get_tokens-notebook'].click(get_token_ids, gradio('textbox-notebook'), gradio('tokens-notebook'), show_progress=False) diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py new file mode 100644 index 0000000..9fbe645 --- /dev/null +++ b/modules/ui_parameters.py @@ -0,0 +1,140 @@ +import gradio as gr + +from modules import loaders, presets, shared, ui, ui_chat, utils +from modules.utils import gradio + + +def create_ui(default_preset): + generate_params = presets.load_preset(default_preset) + with gr.Tab("Parameters", elem_id="parameters"): + with gr.Tab("Generation"): + with gr.Row(): + with gr.Column(): + with gr.Row(): + shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=default_preset, label='Preset', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button') + shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button') + shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button') + + with gr.Column(): + shared.gradio['filter_by_loader'] = gr.Dropdown(label="Filter by loader", choices=["All"] + list(loaders.loaders_and_params.keys()), value="All", elem_classes='slim-dropdown') + + with gr.Row(): + with gr.Column(): + with gr.Box(): + with gr.Row(): + with gr.Column(): + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') + shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p') + shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k') + shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p') + shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff') + shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff') + shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs') + shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a') + + with gr.Column(): + shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty') + shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range') + shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty') + shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') + shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length') + shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') + shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') + + with gr.Accordion("Learn more", open=False): + gr.Markdown(""" + + For a technical description of the parameters, the [transformers documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) is a good reference. + + The best presets, according to the [Preset Arena](https://github.com/oobabooga/oobabooga.github.io/blob/main/arena/results.md) experiment, are: + + * Instruction following: + 1) Divine Intellect + 2) Big O + 3) simple-1 + 4) Space Alien + 5) StarChat + 6) Titanic + 7) tfs-with-top-a + 8) Asterism + 9) Contrastive Search + + * Chat: + 1) Midnight Enigma + 2) Yara + 3) Shortwave + + ### Temperature + Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness. + ### top_p + If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results. + ### top_k + Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results. + ### typical_p + If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text. + ### epsilon_cutoff + In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled. Should be used with top_p, top_k, and eta_cutoff set to 0. + ### eta_cutoff + In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0. + ### repetition_penalty + Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition. + ### repetition_penalty_range + The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used. + ### encoder_repetition_penalty + Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge. + ### no_repeat_ngram_size + If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases. + ### min_length + Minimum generation length in tokens. + ### penalty_alpha + Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4. + + """, elem_classes="markdown") + + with gr.Column(): + with gr.Box(): + with gr.Row(): + with gr.Column(): + shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.') + shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', lines=3, elem_classes=['add_scrollbar']) + shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.') + shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau') + shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta') + + with gr.Column(): + shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.') + shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.') + shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') + shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') + + with gr.Box(): + with gr.Row(): + with gr.Column(): + shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') + shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum number of tokens/second', info='To make text readable in real time.') + shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"') + with gr.Column(): + shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.') + shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.') + shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Custom token bans', info='Specific token IDs to ban from generating, comma-separated. The IDs can be found in the Default or Notebook tab.') + shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.') + shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') + shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming') + + ui_chat.create_chat_settings_ui() + + +def create_event_handlers(): + shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader'), gradio(loaders.list_all_samplers()), show_progress=False) + shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params())) + + +def get_truncation_length(): + if shared.args.max_seq_len != shared.args_defaults.max_seq_len: + return shared.args.max_seq_len + if shared.args.n_ctx != shared.args_defaults.n_ctx: + return shared.args.n_ctx + else: + return shared.settings['truncation_length'] diff --git a/modules/ui_session.py b/modules/ui_session.py new file mode 100644 index 0000000..53d9ec3 --- /dev/null +++ b/modules/ui_session.py @@ -0,0 +1,69 @@ +import gradio as gr + +from modules import shared, ui, utils +from modules.github import clone_or_pull_repository +from modules.utils import gradio + + +def create_ui(): + with gr.Tab("Session", elem_id="session-tab"): + with gr.Row(): + with gr.Column(): + shared.gradio['reset_interface'] = gr.Button("Apply flags/extensions and restart") + with gr.Row(): + shared.gradio['toggle_dark_mode'] = gr.Button('Toggle 💡') + shared.gradio['save_settings'] = gr.Button('Save UI defaults to settings.yaml') + + with gr.Row(): + with gr.Column(): + shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=utils.get_available_extensions(), value=shared.args.extensions, label="Available extensions", info='Note that some of these extensions may require manually installing Python requirements through the command: pip install -r extensions/extension_name/requirements.txt', elem_classes='checkboxgroup-table') + + with gr.Column(): + shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=get_boolean_arguments(), value=get_boolean_arguments(active=True), label="Boolean command-line flags", elem_classes='checkboxgroup-table') + + with gr.Column(): + extension_name = gr.Textbox(lines=1, label='Install or update an extension', info='Enter the GitHub URL below and press Enter. For a list of extensions, see: https://github.com/oobabooga/text-generation-webui-extensions ⚠️ WARNING ⚠️ : extensions can execute arbitrary code. Make sure to inspect their source code before activating them.') + extension_status = gr.Markdown() + + extension_name.submit( + clone_or_pull_repository, extension_name, extension_status, show_progress=False).then( + lambda: gr.update(choices=utils.get_available_extensions(), value=shared.args.extensions), None, gradio('extensions_menu')) + + # Reset interface event + shared.gradio['reset_interface'].click( + set_interface_arguments, gradio('extensions_menu', 'bool_menu'), None).then( + lambda: None, None, None, _js='() => {document.body.innerHTML=\'

      Reloading...

      \'; setTimeout(function(){location.reload()},2500); return []}') + + shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}') + shared.gradio['save_settings'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + ui.save_settings, gradio('interface_state', 'preset_menu', 'instruction_template', 'extensions_menu', 'show_controls'), gradio('save_contents')).then( + lambda: './', None, gradio('save_root')).then( + lambda: 'settings.yaml', None, gradio('save_filename')).then( + lambda: gr.update(visible=True), None, gradio('file_saver')) + + +def set_interface_arguments(extensions, bool_active): + shared.args.extensions = extensions + + bool_list = get_boolean_arguments() + + for k in bool_list: + setattr(shared.args, k, False) + for k in bool_active: + setattr(shared.args, k, True) + + shared.need_restart = True + + +def get_boolean_arguments(active=False): + exclude = ["default", "notebook", "chat"] + + cmd_list = vars(shared.args) + bool_list = sorted([k for k in cmd_list if type(cmd_list[k]) is bool and k not in exclude + ui.list_model_elements()]) + bool_active = [k for k in bool_list if vars(shared.args)[k]] + + if active: + return bool_active + else: + return bool_list diff --git a/modules/utils.py b/modules/utils.py index 72a0dfa..f60597a 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -9,7 +9,7 @@ from modules.logging_colors import logger # Helper function to get multiple values from shared.gradio def gradio(*keys): - if len(keys) == 1 and type(keys[0]) is list: + if len(keys) == 1 and type(keys[0]) in [list, tuple]: keys = keys[0] return [shared.gradio[k] for k in keys] @@ -71,10 +71,12 @@ def natural_keys(text): def get_available_models(): - if shared.args.flexgen: - return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=natural_keys) - else: - return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml'))], key=natural_keys) + model_list = [] + for item in list(Path(f'{shared.args.model_dir}/').glob('*')): + if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml', '.py')) and 'llama-tokenizer' not in item.name: + model_list.append(re.sub('.pth$', '', item.name)) + + return sorted(model_list, key=natural_keys) def get_available_presets(): @@ -86,18 +88,17 @@ def get_available_prompts(): files = set((k.stem for k in Path('prompts').glob('*.txt'))) prompts += sorted([k for k in files if re.match('^[0-9]', k)], key=natural_keys, reverse=True) prompts += sorted([k for k in files if re.match('^[^0-9]', k)], key=natural_keys) - prompts += ['Instruct-' + k for k in get_available_instruction_templates() if k != 'None'] prompts += ['None'] return prompts def get_available_characters(): paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) - return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=natural_keys) + return sorted(set((k.stem for k in paths)), key=natural_keys) def get_available_instruction_templates(): - path = "characters/instruction-following" + path = "instruction-templates" paths = [] if os.path.exists(path): paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) @@ -114,13 +115,12 @@ def get_available_loras(): def get_datasets(path: str, ext: str): + # include subdirectories for raw txt files to allow training from a subdirectory of txt files + if ext == "txt": + return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys) + return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys) def get_available_chat_styles(): return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys) - - -def get_available_sessions(): - items = sorted(set(k.stem for k in Path('logs').glob(f'session_{shared.get_mode()}*')), key=natural_keys, reverse=True) - return [item for item in items if 'autosave' in item] + [item for item in items if 'autosave' not in item] diff --git a/presets/Kobold-Godlike.yaml b/presets/Kobold-Godlike.yaml deleted file mode 100644 index 772a802..0000000 --- a/presets/Kobold-Godlike.yaml +++ /dev/null @@ -1,4 +0,0 @@ -temperature: 0.7 -top_p: 0.5 -typical_p: 0.19 -repetition_penalty: 1.1 diff --git a/requirements.txt b/requirements.txt index e553024..4edf6df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,28 +1,53 @@ -accelerate==0.20.3 +aiofiles==23.1.0 +fastapi==0.95.2 +gradio_client==0.2.5 +gradio==3.33.1 +pydantic==1.10.12 + +accelerate==0.23.* colorama datasets einops -fastapi==0.95.2 -flexgen==0.1.7 -gradio_client==0.2.5 -gradio==3.33.1 +exllamav2==0.0.3 markdown -numpy +numpy==1.24 +optimum==1.13.1 pandas +peft==0.5.* Pillow>=9.5.0 pyyaml requests -safetensors==0.3.1 -sentencepiece -tqdm +safetensors==0.3.2 +transformers==4.33.* scipy -transformers==4.30.2 -git+https://github.com/huggingface/peft@03eb378eb914fbee709ff7c86ba5b1d033b89524 -bitsandbytes==0.39.1; platform_system != "Windows" -https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl; platform_system == "Windows" -llama-cpp-python==0.1.69; platform_system != "Windows" -https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.69/llama_cpp_python-0.1.69-cp310-cp310-win_amd64.whl; platform_system == "Windows" -https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" -https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" -https://github.com/jllllll/exllama/releases/download/0.0.5/exllama-0.0.5+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" -https://github.com/jllllll/exllama/releases/download/0.0.5/exllama-0.0.5+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +sentencepiece +tensorboard +tqdm +wandb + +# bitsandbytes +bitsandbytes==0.41.1; platform_system != "Windows" +https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" + +# AutoGPTQ +https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.2/auto_gptq-0.4.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" +https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.2/auto_gptq-0.4.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" + +# ExLlama +https://github.com/jllllll/exllama/releases/download/0.0.17/exllama-0.0.17+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" +https://github.com/jllllll/exllama/releases/download/0.0.17/exllama-0.0.17+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" + +# llama-cpp-python without GPU support +llama-cpp-python==0.2.6; platform_system != "Windows" +https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.6/llama_cpp_python-0.2.6-cp310-cp310-win_amd64.whl; platform_system == "Windows" + +# llama-cpp-python with CUDA support +https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.6+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" +https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.6+cu117-cp310-cp310-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" + +# GPTQ-for-LLaMa +https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" +https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" + +# ctransformers +https://github.com/jllllll/ctransformers-cuBLAS-wheels/releases/download/AVX2/ctransformers-0.2.27+cu117-py3-none-any.whl diff --git a/requirements_nocuda.txt b/requirements_nocuda.txt new file mode 100644 index 0000000..51a1e97 --- /dev/null +++ b/requirements_nocuda.txt @@ -0,0 +1,30 @@ +aiofiles==23.1.0 +fastapi==0.95.2 +gradio_client==0.2.5 +gradio==3.33.1 +pydantic==1.10.12 + +accelerate==0.23.* +colorama +datasets +einops +exllamav2==0.0.3 +markdown +numpy==1.24 +optimum==1.13.1 +pandas +peft==0.5.* +Pillow>=9.5.0 +pyyaml +requests +safetensors==0.3.2 +transformers==4.33.* +scipy +sentencepiece +tensorboard +tqdm +wandb + +# bitsandbytes +bitsandbytes==0.41.1; platform_system != "Windows" +https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/server.py b/server.py index 18474d3..fc99ef7 100644 --- a/server.py +++ b/server.py @@ -1,8 +1,8 @@ import os import warnings -from modules.logging_colors import logger from modules.block_requests import OpenMonkeyPatch, RequestBlocker +from modules.logging_colors import logger os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['BITSANDBYTES_NOWELCOME'] = '1' @@ -12,1071 +12,150 @@ with RequestBlocker(): import gradio as gr import matplotlib + matplotlib.use('Agg') # This fixes LaTeX rendering on some systems -import importlib import json -import math import os -import re import sys import time -import traceback from functools import partial from pathlib import Path from threading import Lock -import psutil -import torch import yaml -from PIL import Image import modules.extensions as extensions_module -from modules import chat, loaders, presets, shared, training, ui, utils -from modules.extensions import apply_extensions -from modules.github import clone_or_pull_repository -from modules.html_generator import chat_html_wrapper -from modules.LoRA import add_lora_to_model -from modules.models import load_model, unload_model -from modules.models_settings import ( - apply_model_settings_to_state, - get_model_settings_from_yamls, - save_model_settings, - update_model_parameters +from modules import ( + chat, + shared, + training, + ui, + ui_chat, + ui_default, + ui_file_saving, + ui_model_menu, + ui_notebook, + ui_parameters, + ui_session, + utils ) -from modules.text_generation import ( - generate_reply_wrapper, - get_encoded_length, - stop_everything_event +from modules.extensions import apply_extensions +from modules.LoRA import add_lora_to_model +from modules.models import load_model +from modules.models_settings import ( + get_fallback_settings, + get_model_metadata, + update_model_parameters ) from modules.utils import gradio -def load_model_wrapper(selected_model, loader, autoload=False): - if not autoload: - yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it." - return - - if selected_model == 'None': - yield "No model selected" - else: - try: - yield f"Loading {selected_model}..." - shared.model_name = selected_model - unload_model() - if selected_model != '': - shared.model, shared.tokenizer = load_model(shared.model_name, loader) - - if shared.model is not None: - yield f"Successfully loaded {selected_model}" - else: - yield f"Failed to load {selected_model}." - except: - exc = traceback.format_exc() - logger.error('Failed to load the model.') - print(exc) - yield exc - - -def load_lora_wrapper(selected_loras): - yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras))) - add_lora_to_model(selected_loras) - yield ("Successfuly applied the LoRAs") - - -def load_prompt(fname): - if fname in ['None', '']: - return '' - elif fname.startswith('Instruct-'): - fname = re.sub('^Instruct-', '', fname) - file_path = Path(f'characters/instruction-following/{fname}.yaml') - if not file_path.exists(): - return '' - - with open(file_path, 'r', encoding='utf-8') as f: - data = yaml.safe_load(f) - output = '' - if 'context' in data: - output += data['context'] - - replacements = { - '<|user|>': data['user'], - '<|bot|>': data['bot'], - '<|user-message|>': 'Input', - } - - output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements) - return output.rstrip(' ') - else: - file_path = Path(f'prompts/{fname}.txt') - if not file_path.exists(): - return '' - - with open(file_path, 'r', encoding='utf-8') as f: - text = f.read() - if text[-1] == '\n': - text = text[:-1] - - return text - - -def count_tokens(text): - try: - tokens = get_encoded_length(text) - return f'{tokens} tokens in the input.' - except: - return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?' - - -def download_model_wrapper(repo_id, progress=gr.Progress()): - try: - downloader_module = importlib.import_module("download-model") - downloader = downloader_module.ModelDownloader() - repo_id_parts = repo_id.split(":") - model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id - branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main" - check = False - - progress(0.0) - yield ("Cleaning up the model/branch names") - model, branch = downloader.sanitize_model_and_branch_names(model, branch) - - yield ("Getting the download links from Hugging Face") - links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False) - - yield ("Getting the output folder") - output_folder = downloader.get_output_folder(model, branch, is_lora) - - if check: - progress(0.5) - yield ("Checking previously downloaded files") - downloader.check_model_files(model, branch, links, sha256, output_folder) - progress(1.0) - else: - yield (f"Downloading files to {output_folder}") - downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1) - yield ("Done!") - except: - progress(1.0) - yield traceback.format_exc() - - -def create_model_menus(): - # Finding the default values for the GPU and CPU memories - total_mem = [] - for i in range(torch.cuda.device_count()): - total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024))) - - default_gpu_mem = [] - if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0: - for i in shared.args.gpu_memory: - if 'mib' in i.lower(): - default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i))) - else: - default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000) - while len(default_gpu_mem) < len(total_mem): - default_gpu_mem.append(0) - - total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024)) - if shared.args.cpu_memory is not None: - default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory) - else: - default_cpu_mem = 0 - - with gr.Row(): - with gr.Column(): - with gr.Row(): - with gr.Column(): - with gr.Row(): - shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=shared.model_name, label='Model', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button') - load = gr.Button("Load", visible=not shared.settings['autoload_model'], elem_classes='refresh-button') - unload = gr.Button("Unload", elem_classes='refresh-button') - reload = gr.Button("Reload", elem_classes='refresh-button') - save_settings = gr.Button("Save settings", elem_classes='refresh-button') - - with gr.Column(): - with gr.Row(): - shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button') - shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button') - - with gr.Row(): - with gr.Column(): - shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "ExLlama_HF", "llama.cpp"], value=None) - with gr.Box(): - with gr.Row(): - with gr.Column(): - for i in range(len(total_mem)): - shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i]) - - shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem) - shared.gradio['transformers_info'] = gr.Markdown('load-in-4bit params:') - shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype) - shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type) - shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads) - shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch) - shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers) - shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx) - shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None") - shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None") - shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None") - shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0) - shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.') - shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7') - shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len) - shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb) - shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=8, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value) - - with gr.Column(): - shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton) - shared.gradio['no_inject_fused_attention'] = gr.Checkbox(label="no_inject_fused_attention", value=shared.args.no_inject_fused_attention, info='Disable fused attention. Fused attention improves inference performance but uses more VRAM. Disable if running low on VRAM.') - shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp", value=shared.args.no_inject_fused_mlp, info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.') - shared.gradio['no_use_cuda_fp16'] = gr.Checkbox(label="no_use_cuda_fp16", value=shared.args.no_use_cuda_fp16, info='This can make models faster on some systems.') - shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.') - shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu) - shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit) - shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16) - shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices) - shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk) - shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit) - shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant) - shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap) - shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock) - shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed) - shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.') - shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).') - shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).') - shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.') - - with gr.Column(): - with gr.Row(): - shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.') - - shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main") - shared.gradio['download_model_button'] = gr.Button("Download") - - with gr.Row(): - shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready') - - shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params())) - - # In this event handler, the interface state is read and updated - # with the model defaults (if any), and then the model is loaded - # unless "autoload_model" is unchecked - shared.gradio['model_menu'].change( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then( - ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then( - update_model_parameters, gradio('interface_state'), None).then( - load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False) - - load.click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - update_model_parameters, gradio('interface_state'), None).then( - partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).then( - lambda: shared.lora_names, None, gradio('lora_menu')) - - unload.click( - unload_model, None, None).then( - lambda: "Model unloaded", None, gradio('model_status')).then( - lambda: shared.lora_names, None, gradio('lora_menu')) - - reload.click( - unload_model, None, None).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - update_model_parameters, gradio('interface_state'), None).then( - partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).then( - lambda: shared.lora_names, None, gradio('lora_menu')) - - save_settings.click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False) - - shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False) - shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu'), gradio('model_status'), show_progress=True) - shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), load) - - -def create_chat_settings_menus(): - if not shared.is_chat(): - return - - with gr.Box(): - gr.Markdown("Chat parameters") - with gr.Row(): - with gr.Column(): - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations.') - - with gr.Column(): - shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character') - - -def create_settings_menus(default_preset): - generate_params = presets.load_preset(default_preset) - with gr.Row(): - with gr.Column(): - with gr.Row(): - shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button') - shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button') - shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button') - - with gr.Column(): - shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') - - with gr.Row(): - with gr.Column(): - with gr.Box(): - gr.Markdown('Main parameters') - with gr.Row(): - with gr.Column(): - shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') - shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p') - shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k') - shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p') - shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff') - shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff') - - with gr.Column(): - shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty') - shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range') - shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty') - shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') - shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length') - shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs') - shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a') - shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') - - with gr.Accordion("Learn more", open=False): - gr.Markdown(""" - - Not all parameters are used by all loaders. See [this page](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md) for details. - - For a technical description of the parameters, the [transformers documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) is a good reference. - - The best presets, according to the [Preset Arena](https://github.com/oobabooga/oobabooga.github.io/blob/main/arena/results.md) experiment, are: - - * Instruction following: - 1) Divine Intellect - 2) Big O - 3) simple-1 - 4) Space Alien - 5) StarChat - 6) Titanic - 7) tfs-with-top-a - 8) Asterism - 9) Contrastive Search - - * Chat: - 1) Midnight Enigma - 2) Yara - 3) Shortwave - 4) Kobold-Godlike - - ### Temperature - Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness. - ### top_p - If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results. - ### top_k - Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results. - ### typical_p - If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text. - ### epsilon_cutoff - In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled. Should be used with top_p, top_k, and eta_cutoff set to 0. - ### eta_cutoff - In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0. - ### repetition_penalty - Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition. - ### repetition_penalty_range - The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used. - ### encoder_repetition_penalty - Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge. - ### no_repeat_ngram_size - If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases. - ### min_length - Minimum generation length in tokens. - ### penalty_alpha - Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4. - - """, elem_classes="markdown") - - with gr.Column(): - create_chat_settings_menus() - with gr.Box(): - with gr.Row(): - with gr.Column(): - gr.Markdown('Contrastive search') - shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') - - gr.Markdown('Beam search') - shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams') - shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') - shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - - with gr.Column(): - gr.Markdown('Mirostat (mode=1 is only for llama.cpp)') - shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode') - shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau') - shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta') - - with gr.Box(): - with gr.Row(): - with gr.Column(): - shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') - shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"') - with gr.Column(): - shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.') - shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.') - - shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') - shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming') - - shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a')) - - -def create_file_saving_menus(): - - # Text file saver - with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']: - shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name') - shared.gradio['save_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False) - shared.gradio['save_contents'] = gr.Textbox(lines=10, label='File contents') - with gr.Row(): - shared.gradio['save_confirm'] = gr.Button('Save', elem_classes="small-button") - shared.gradio['save_cancel'] = gr.Button('Cancel', elem_classes="small-button") - - # Text file deleter - with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_deleter']: - shared.gradio['delete_filename'] = gr.Textbox(lines=1, label='File name') - shared.gradio['delete_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False) - with gr.Row(): - shared.gradio['delete_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop') - shared.gradio['delete_cancel'] = gr.Button('Cancel', elem_classes="small-button") - - # Character saver/deleter - if shared.is_chat(): - with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']: - shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info='The character will be saved to your characters/ folder with this base filename.') - with gr.Row(): - shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button") - shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button") - - with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_deleter']: - gr.Markdown('Confirm the character deletion?') - with gr.Row(): - shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop') - shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button") - - -def create_file_saving_event_handlers(): - shared.gradio['save_confirm'].click( - lambda x, y, z: utils.save_file(x + y, z), gradio('save_root', 'save_filename', 'save_contents'), None).then( - lambda: gr.update(visible=False), None, gradio('file_saver')) - - shared.gradio['delete_confirm'].click( - lambda x, y: utils.delete_file(x + y), gradio('delete_root', 'delete_filename'), None).then( - lambda: gr.update(visible=False), None, gradio('file_deleter')) - - shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_deleter')) - shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_saver')) - if shared.is_chat(): - shared.gradio['save_character_confirm'].click( - chat.save_character, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), None).then( - lambda: gr.update(visible=False), None, gradio('character_saver')) - - shared.gradio['delete_character_confirm'].click( - chat.delete_character, gradio('character_menu'), None).then( - lambda: gr.update(visible=False), None, gradio('character_deleter')).then( - lambda: gr.update(choices=utils.get_available_characters()), None, gradio('character_menu')) - - shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver')) - shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter')) - - shared.gradio['save_preset'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - presets.generate_preset_yaml, gradio('interface_state'), gradio('save_contents')).then( - lambda: 'presets/', None, gradio('save_root')).then( - lambda: 'My Preset.yaml', None, gradio('save_filename')).then( - lambda: gr.update(visible=True), None, gradio('file_saver')) - - shared.gradio['delete_preset'].click( - lambda x: f'{x}.yaml', gradio('preset_menu'), gradio('delete_filename')).then( - lambda: 'presets/', None, gradio('delete_root')).then( - lambda: gr.update(visible=True), None, gradio('file_deleter')) - - if not shared.args.multi_user: - - def load_session(session, state): - with open(Path(f'logs/{session}.json'), 'r') as f: - state.update(json.loads(f.read())) - - if shared.is_chat(): - chat.save_persistent_history(state['history'], state['character_menu'], state['mode']) - - return state - - if shared.is_chat(): - shared.gradio['save_session'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('save_contents')).then( - lambda: 'logs/', None, gradio('save_root')).then( - lambda x: f'session_{shared.get_mode()}_{x + "_" if x not in ["None", None, ""] else ""}{utils.current_time()}.json', gradio('character_menu'), gradio('save_filename')).then( - lambda: gr.update(visible=True), None, gradio('file_saver')) - - shared.gradio['session_menu'].change( - load_session, gradio('session_menu', 'interface_state'), gradio('interface_state')).then( - ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) - - else: - shared.gradio['save_session'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('save_contents')).then( - lambda: 'logs/', None, gradio('save_root')).then( - lambda: f'session_{shared.get_mode()}_{utils.current_time()}.json', None, gradio('save_filename')).then( - lambda: gr.update(visible=True), None, gradio('file_saver')) - - shared.gradio['session_menu'].change( - load_session, gradio('session_menu', 'interface_state'), gradio('interface_state')).then( - ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False) - - shared.gradio['delete_session'].click( - lambda x: f'{x}.json', gradio('session_menu'), gradio('delete_filename')).then( - lambda: 'logs/', None, gradio('delete_root')).then( - lambda: gr.update(visible=True), None, gradio('file_deleter')) - - -def set_interface_arguments(interface_mode, extensions, bool_active): - modes = ["default", "notebook", "chat", "cai_chat"] - cmd_list = vars(shared.args) - bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] - - shared.args.extensions = extensions - for k in modes[1:]: - setattr(shared.args, k, False) - if interface_mode != "default": - setattr(shared.args, interface_mode, True) - - for k in bool_list: - setattr(shared.args, k, False) - for k in bool_active: - setattr(shared.args, k, True) - - shared.need_restart = True - - def create_interface(): - # Defining some variables - gen_events = [] - default_preset = shared.settings['preset'] - default_text = load_prompt(shared.settings['prompt']) title = 'Text generation web UI' - # Authentication variables - auth = None - gradio_auth_creds = [] + # Password authentication + auth = [] if shared.args.gradio_auth: - gradio_auth_creds += [x.strip() for x in shared.args.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()] - if shared.args.gradio_auth_path is not None: + auth.extend(x.strip() for x in shared.args.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()) + if shared.args.gradio_auth_path: with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file: - for line in file.readlines(): - gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] - if gradio_auth_creds: - auth = [tuple(cred.split(':')) for cred in gradio_auth_creds] + auth.extend(x.strip() for line in file for x in line.split(',') if x.strip()) + auth = [tuple(cred.split(':')) for cred in auth] - # Importing the extension files and executing their setup() functions + # Import the extensions and execute their setup() functions if shared.args.extensions is not None and len(shared.args.extensions) > 0: extensions_module.load_extensions() + # Force some events to be triggered on page load + shared.persistent_interface_state.update({ + 'loader': shared.args.loader or 'Transformers', + 'mode': shared.settings['mode'], + 'character_menu': shared.args.character or shared.settings['character'], + 'instruction_template': shared.settings['instruction_template'], + 'prompt_menu-default': shared.settings['prompt-default'], + 'prompt_menu-notebook': shared.settings['prompt-notebook'], + }) + + if Path("cache/pfp_character.png").exists(): + Path("cache/pfp_character.png").unlink() + # css/js strings - css = ui.css if not shared.is_chat() else ui.css + ui.chat_css - js = ui.main_js if not shared.is_chat() else ui.main_js + ui.chat_js + css = ui.css + js = ui.js css += apply_extensions('css') js += apply_extensions('js') + # Interface state elements + shared.input_elements = ui.list_interface_input_elements() + with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']: + + # Interface state + shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) + + # Audio notification if Path("notification.mp3").exists(): shared.gradio['audio_notification'] = gr.Audio(interactive=False, value="notification.mp3", elem_id="audio_notification", visible=False) - audio_notification_js = "document.querySelector('#audio_notification audio')?.play();" - else: - audio_notification_js = "" # Floating menus for saving/deleting files - create_file_saving_menus() + ui_file_saving.create_ui() - # Create chat mode interface - if shared.is_chat(): - shared.input_elements = ui.list_interface_input_elements() + # Temporary clipboard for saving files + shared.gradio['temporary_text'] = gr.Textbox(visible=False) - shared.gradio.update({ - 'interface_state': gr.State({k: None for k in shared.input_elements}), - 'Chat input': gr.State(), - 'dummy': gr.State(), - 'history': gr.State({'internal': [], 'visible': []}), - }) + # Text Generation tab + ui_chat.create_ui() + ui_default.create_ui() + ui_notebook.create_ui() - with gr.Tab('Text generation', elem_id='main'): - shared.gradio['display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': []}, shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat')) - shared.gradio['textbox'] = gr.Textbox(label='Input') - with gr.Row(): - shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop') - shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary') - shared.gradio['Continue'] = gr.Button('Continue') + ui_parameters.create_ui(shared.settings['preset']) # Parameters tab + ui_model_menu.create_ui() # Model tab + training.create_ui() # Training tab + ui_session.create_ui() # Session tab - with gr.Row(): - shared.gradio['Impersonate'] = gr.Button('Impersonate') - shared.gradio['Regenerate'] = gr.Button('Regenerate') - shared.gradio['Remove last'] = gr.Button('Remove last') + # Generation events + ui_chat.create_event_handlers() + ui_default.create_event_handlers() + ui_notebook.create_event_handlers() - with gr.Row(): - shared.gradio['Copy last reply'] = gr.Button('Copy last reply') - shared.gradio['Replace last reply'] = gr.Button('Replace last reply') - shared.gradio['Send dummy message'] = gr.Button('Send dummy message') - shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply') + # Other events + ui_file_saving.create_event_handlers() + ui_parameters.create_event_handlers() + ui_model_menu.create_event_handlers() - with gr.Row(): - shared.gradio['Clear history'] = gr.Button('Clear history') - shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False) - shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) - - with gr.Row(): - shared.gradio['start_with'] = gr.Textbox(label='Start reply with', placeholder='Sure thing!', value=shared.settings['start_with']) - - with gr.Row(): - shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'], value=shared.settings['mode'] if shared.settings['mode'] in ['chat', 'instruct', 'chat-instruct'] else 'chat', label='Mode', info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under "Chat settings" must match the current model.') - shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct') - - with gr.Tab('Chat settings', elem_id='chat-settings'): - - with gr.Tab("Character"): - with gr.Row(): - with gr.Column(scale=8): - with gr.Row(): - shared.gradio['character_menu'] = gr.Dropdown(value='None', choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button') - shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button') - shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button') - - shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') - shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') - shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context') - shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting') - - with gr.Column(scale=1): - shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil') - shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None) - - with gr.Tab("Instruction template"): - with gr.Row(): - with gr.Row(): - shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button') - shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button') - shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button') - - shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string') - shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string') - shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context') - shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.') - with gr.Row(): - shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.') - - with gr.Tab('Chat history'): - with gr.Row(): - with gr.Column(): - shared.gradio['download'] = gr.File(label="Download") - shared.gradio['download_button'] = gr.Button(value='Refresh') - - with gr.Column(): - shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'], label="Upload") - - with gr.Tab('Upload character'): - with gr.Tab('JSON'): - with gr.Row(): - shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'], label='JSON File') - shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)') - - shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False) - - with gr.Tab('TavernAI'): - with gr.Row(): - with gr.Column(): - shared.gradio['upload_img_tavern'] = gr.Image(type='pil', label='TavernAI PNG File', elem_id="upload_img_tavern") - shared.gradio['tavern_json'] = gr.State() - with gr.Column(): - shared.gradio['tavern_name'] = gr.Textbox(value='', lines=1, label='Name', interactive=False) - shared.gradio['tavern_desc'] = gr.Textbox(value='', lines=4, max_lines=4, label='Description', interactive=False) - - shared.gradio['Submit tavern character'] = gr.Button(value='Submit', interactive=False) - - with gr.Tab("Parameters", elem_id="parameters"): - create_settings_menus(default_preset) - - # Create notebook mode interface - elif shared.args.notebook: - shared.input_elements = ui.list_interface_input_elements() - shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) - shared.gradio['last_input'] = gr.State('') - with gr.Tab("Text generation", elem_id="main"): - with gr.Row(): - with gr.Column(scale=4): - with gr.Tab('Raw'): - shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox", lines=27) - - with gr.Tab('Markdown'): - shared.gradio['markdown_render'] = gr.Button('Render') - shared.gradio['markdown'] = gr.Markdown() - - with gr.Tab('HTML'): - shared.gradio['html'] = gr.HTML() - - with gr.Row(): - shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button") - shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button") - shared.gradio['Undo'] = gr.Button('Undo', elem_classes="small-button") - shared.gradio['Regenerate'] = gr.Button('Regenerate', elem_classes="small-button") - - with gr.Column(scale=1): - gr.HTML('
      ') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - with gr.Row(): - shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, ['refresh-button', 'refresh-button-small']) - shared.gradio['save_prompt'] = gr.Button('💾', elem_classes=['refresh-button', 'refresh-button-small']) - shared.gradio['delete_prompt'] = gr.Button('🗑️', elem_classes=['refresh-button', 'refresh-button-small']) - - shared.gradio['count_tokens'] = gr.Button('Count tokens') - shared.gradio['status'] = gr.Markdown('') - - with gr.Tab("Parameters", elem_id="parameters"): - create_settings_menus(default_preset) - - # Create default mode interface - else: - shared.input_elements = ui.list_interface_input_elements() - shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) - shared.gradio['last_input'] = gr.State('') - with gr.Tab("Text generation", elem_id="main"): - with gr.Row(): - with gr.Column(): - shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox_default", lines=27, label='Input') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - with gr.Row(): - shared.gradio['Generate'] = gr.Button('Generate', variant='primary') - shared.gradio['Stop'] = gr.Button('Stop') - shared.gradio['Continue'] = gr.Button('Continue') - shared.gradio['count_tokens'] = gr.Button('Count tokens') - - with gr.Row(): - shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button') - shared.gradio['save_prompt'] = gr.Button('💾', elem_classes='refresh-button') - shared.gradio['delete_prompt'] = gr.Button('🗑️', elem_classes='refresh-button') - - shared.gradio['status'] = gr.Markdown('') - - with gr.Column(): - with gr.Tab('Raw'): - shared.gradio['output_textbox'] = gr.Textbox(elem_classes="textbox_default_output", lines=27, label='Output') - - with gr.Tab('Markdown'): - shared.gradio['markdown_render'] = gr.Button('Render') - shared.gradio['markdown'] = gr.Markdown() - - with gr.Tab('HTML'): - shared.gradio['html'] = gr.HTML() - - with gr.Tab("Parameters", elem_id="parameters"): - create_settings_menus(default_preset) - - # Model tab - with gr.Tab("Model", elem_id="model-tab"): - create_model_menus() - - # Training tab - with gr.Tab("Training", elem_id="training-tab"): - training.create_train_interface() - - # Session tab - with gr.Tab("Session", elem_id="session-tab"): - modes = ["default", "notebook", "chat"] - current_mode = "default" - for mode in modes[1:]: - if getattr(shared.args, mode): - current_mode = mode - break - - cmd_list = vars(shared.args) - bool_list = sorted([k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes + ui.list_model_elements()]) - bool_active = [k for k in bool_list if vars(shared.args)[k]] - - with gr.Row(): - - with gr.Column(): - with gr.Row(): - shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode", elem_classes='slim-dropdown') - shared.gradio['reset_interface'] = gr.Button("Apply and restart", elem_classes="small-button", variant="primary") - shared.gradio['toggle_dark_mode'] = gr.Button('Toggle 💡', elem_classes="small-button") - - with gr.Row(): - with gr.Column(): - shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=utils.get_available_extensions(), value=shared.args.extensions, label="Available extensions", info='Note that some of these extensions may require manually installing Python requirements through the command: pip install -r extensions/extension_name/requirements.txt', elem_classes='checkboxgroup-table') - - with gr.Column(): - shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags", elem_classes='checkboxgroup-table') - - with gr.Column(): - if not shared.args.multi_user: - with gr.Row(): - shared.gradio['session_menu'] = gr.Dropdown(choices=utils.get_available_sessions(), value='None', label='Session', elem_classes='slim-dropdown', info='When saving a session, make sure to keep the initial part of the filename (session_chat, session_notebook, or session_default), otherwise it will not appear on this list afterwards.') - ui.create_refresh_button(shared.gradio['session_menu'], lambda: None, lambda: {'choices': utils.get_available_sessions()}, ['refresh-button']) - shared.gradio['save_session'] = gr.Button('💾', elem_classes=['refresh-button']) - shared.gradio['delete_session'] = gr.Button('🗑️', elem_classes=['refresh-button']) - - extension_name = gr.Textbox(lines=1, label='Install or update an extension', info='Enter the GitHub URL below and press Enter. For a list of extensions, see: https://github.com/oobabooga/text-generation-webui-extensions ⚠️ WARNING ⚠️ : extensions can execute arbitrary code. Make sure to inspect their source code before activating them.') - extension_status = gr.Markdown() - - extension_name.submit( - clone_or_pull_repository, extension_name, extension_status, show_progress=False).then( - lambda: gr.update(choices=utils.get_available_extensions(), value=shared.args.extensions), None, gradio('extensions_menu')) - - # Reset interface event - shared.gradio['reset_interface'].click( - set_interface_arguments, gradio('interface_modes_menu', 'extensions_menu', 'bool_menu'), None).then( - lambda: None, None, None, _js='() => {document.body.innerHTML=\'

      Reloading...

      \'; setTimeout(function(){location.reload()},2500); return []}') - - shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}') - - # chat mode event handlers - if shared.is_chat(): - shared.input_params = gradio('Chat input', 'start_with', 'interface_state') - clear_arr = gradio('Clear history-confirm', 'Clear history', 'Clear history-cancel') - shared.reload_inputs = gradio('history', 'name1', 'name2', 'mode', 'chat_style') - - gen_events.append(shared.gradio['Generate'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then( - chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - ) - - gen_events.append(shared.gradio['textbox'].submit( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then( - chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - ) - - gen_events.append(shared.gradio['Regenerate'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - partial(chat.generate_chat_reply_wrapper, regenerate=True), shared.input_params, gradio('display', 'history'), show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - ) - - gen_events.append(shared.gradio['Continue'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - partial(chat.generate_chat_reply_wrapper, _continue=True), shared.input_params, gradio('display', 'history'), show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - ) - - gen_events.append(shared.gradio['Impersonate'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then( - chat.impersonate_wrapper, shared.input_params, gradio('textbox'), show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - ) - - shared.gradio['Replace last reply'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.replace_last_reply, gradio('textbox', 'interface_state'), gradio('history')).then( - lambda: '', None, gradio('textbox'), show_progress=False).then( - chat.redraw_html, shared.reload_inputs, gradio('display')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None) - - shared.gradio['Send dummy message'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.send_dummy_message, gradio('textbox', 'interface_state'), gradio('history')).then( - lambda: '', None, gradio('textbox'), show_progress=False).then( - chat.redraw_html, shared.reload_inputs, gradio('display')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None) - - shared.gradio['Send dummy reply'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.send_dummy_reply, gradio('textbox', 'interface_state'), gradio('history')).then( - lambda: '', None, gradio('textbox'), show_progress=False).then( - chat.redraw_html, shared.reload_inputs, gradio('display')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None) - - shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) - shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) - shared.gradio['Clear history-confirm'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then( - chat.clear_chat_log, gradio('interface_state'), gradio('history')).then( - chat.redraw_html, shared.reload_inputs, gradio('display')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None) - - shared.gradio['Remove last'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.remove_last_message, gradio('history'), gradio('textbox', 'history'), show_progress=False).then( - chat.redraw_html, shared.reload_inputs, gradio('display')).then( - chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None) - - shared.gradio['character_menu'].change( - partial(chat.load_character, instruct=False), gradio('character_menu', 'name1', 'name2'), gradio('name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy')).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - chat.load_persistent_history, gradio('interface_state'), gradio('history')).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) - - shared.gradio['Stop'].click( - stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) - - shared.gradio['mode'].change( - lambda x: gr.update(visible=x != 'instruct'), gradio('mode'), gradio('chat_style'), show_progress=False).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) - - shared.gradio['chat_style'].change(chat.redraw_html, shared.reload_inputs, gradio('display')) - shared.gradio['instruction_template'].change( - partial(chat.load_character, instruct=True), gradio('instruction_template', 'name1_instruct', 'name2_instruct'), gradio('name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template')) - - shared.gradio['upload_chat_history'].upload( - chat.load_history, gradio('upload_chat_history', 'history'), gradio('history')).then( - chat.redraw_html, shared.reload_inputs, gradio('display')) - - shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, gradio('history'), gradio('textbox'), show_progress=False) - - # Save/delete a character - shared.gradio['save_character'].click( - lambda x: x, gradio('name2'), gradio('save_character_filename')).then( - lambda: gr.update(visible=True), None, gradio('character_saver')) - - shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter')) - - shared.gradio['save_template'].click( - lambda: 'My Template.yaml', None, gradio('save_filename')).then( - lambda: 'characters/instruction-following/', None, gradio('save_root')).then( - chat.generate_instruction_template_yaml, gradio('name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template'), gradio('save_contents')).then( - lambda: gr.update(visible=True), None, gradio('file_saver')) - - shared.gradio['delete_template'].click( - lambda x: f'{x}.yaml', gradio('instruction_template'), gradio('delete_filename')).then( - lambda: 'characters/instruction-following/', None, gradio('delete_root')).then( - lambda: gr.update(visible=True), None, gradio('file_deleter')) - - shared.gradio['download_button'].click(chat.save_history, gradio('history'), gradio('download')) - shared.gradio['Submit character'].click(chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu')) - shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character')) - shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character')) - - shared.gradio['Submit tavern character'].click(chat.upload_tavern_character, gradio('upload_img_tavern', 'tavern_json'), gradio('character_menu')) - shared.gradio['upload_img_tavern'].upload(chat.check_tavern_character, gradio('upload_img_tavern'), gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False) - shared.gradio['upload_img_tavern'].clear(lambda: (None, None, None, gr.update(interactive=False)), None, gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False) - shared.gradio['your_picture'].change( - chat.upload_your_profile_picture, gradio('your_picture'), None).then( - partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, gradio('display')) - - # notebook/default modes event handlers - else: - shared.input_params = gradio('textbox', 'interface_state') - if shared.args.notebook: - output_params = gradio('textbox', 'html') - else: - output_params = gradio('output_textbox', 'html') - - gen_events.append(shared.gradio['Generate'].click( - lambda x: x, gradio('textbox'), gradio('last_input')).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") - ) - - gen_events.append(shared.gradio['textbox'].submit( - lambda x: x, gradio('textbox'), gradio('last_input')).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") - ) - - if shared.args.notebook: - shared.gradio['Undo'].click(lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False) - shared.gradio['markdown_render'].click(lambda x: x, gradio('textbox'), gradio('markdown'), queue=False) - gen_events.append(shared.gradio['Regenerate'].click( - lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") - ) - else: - shared.gradio['markdown_render'].click(lambda x: x, gradio('output_textbox'), gradio('markdown'), queue=False) - gen_events.append(shared.gradio['Continue'].click( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - generate_reply_wrapper, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False).then( - ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( - lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") - # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") - ) - - shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None) - shared.gradio['prompt_menu'].change(load_prompt, gradio('prompt_menu'), gradio('textbox'), show_progress=False) - shared.gradio['save_prompt'].click( - lambda x: x, gradio('textbox'), gradio('save_contents')).then( - lambda: 'prompts/', None, gradio('save_root')).then( - lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then( - lambda: gr.update(visible=True), None, gradio('file_saver')) - - shared.gradio['delete_prompt'].click( - lambda: 'prompts/', None, gradio('delete_root')).then( - lambda x: x + '.txt', gradio('prompt_menu'), gradio('delete_filename')).then( - lambda: gr.update(visible=True), None, gradio('file_deleter')) - - shared.gradio['count_tokens'].click(count_tokens, gradio('textbox'), gradio('status'), show_progress=False) - - create_file_saving_event_handlers() - - shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => {{{js}}}") - shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False) + # Interface launch events if shared.settings['dark_theme']: shared.gradio['interface'].load(lambda: None, None, None, _js="() => document.getElementsByTagName('body')[0].classList.add('dark')") - if shared.is_chat(): - shared.gradio['interface'].load(chat.redraw_html, shared.reload_inputs, gradio('display')) + shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => {{{js}}}") + shared.gradio['interface'].load(None, gradio('show_controls'), None, _js=f'(x) => {{{ui.show_controls_js}; toggle_controls(x)}}') + shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False) + shared.gradio['interface'].load(chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display')) - # Extensions tabs - extensions_module.create_extensions_tabs() - - # Extensions block - extensions_module.create_extensions_block() + extensions_module.create_extensions_tabs() # Extensions tabs + extensions_module.create_extensions_block() # Extensions block # Launch the interface - shared.gradio['interface'].queue() + shared.gradio['interface'].queue(concurrency_count=64) with OpenMonkeyPatch(): - if shared.args.listen: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name=shared.args.listen_host or '0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) - else: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) + shared.gradio['interface'].launch( + prevent_thread_lock=True, + share=shared.args.share, + server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'), + server_port=shared.args.listen_port, + inbrowser=shared.args.auto_launch, + auth=auth or None, + ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True, + ssl_keyfile=shared.args.ssl_keyfile, + ssl_certfile=shared.args.ssl_certfile + ) if __name__ == "__main__": - # Loading custom settings + + # Load custom settings settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): settings_file = Path(shared.args.settings) @@ -1089,35 +168,18 @@ if __name__ == "__main__": logger.info(f"Loading settings from {settings_file}...") file_contents = open(settings_file, 'r', encoding='utf-8').read() new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents) - for item in new_settings: - shared.settings[item] = new_settings[item] - - # Set default model settings based on settings file - shared.model_config['.*'] = { - 'wbits': 'None', - 'model_type': 'None', - 'groupsize': 'None', - 'pre_layer': 0, - 'mode': shared.settings['mode'], - 'skip_special_tokens': shared.settings['skip_special_tokens'], - 'custom_stopping_strings': shared.settings['custom_stopping_strings'], - 'truncation_length': shared.settings['truncation_length'], - } + shared.settings.update(new_settings) + # Fallback settings for models + shared.model_config['.*'] = get_fallback_settings() shared.model_config.move_to_end('.*', last=False) # Move to the beginning - # Default extensions + # Activate the extensions listed on settings.yaml extensions_module.available_extensions = utils.get_available_extensions() - if shared.is_chat(): - for extension in shared.settings['chat_default_extensions']: - shared.args.extensions = shared.args.extensions or [] - if extension not in shared.args.extensions: - shared.args.extensions.append(extension) - else: - for extension in shared.settings['default_extensions']: - shared.args.extensions = shared.args.extensions or [] - if extension not in shared.args.extensions: - shared.args.extensions.append(extension) + for extension in shared.settings['default_extensions']: + shared.args.extensions = shared.args.extensions or [] + if extension not in shared.args.extensions: + shared.args.extensions.append(extension) available_models = utils.get_available_models() @@ -1125,10 +187,6 @@ if __name__ == "__main__": if shared.args.model is not None: shared.model_name = shared.args.model - # Only one model is available - elif len(available_models) == 1: - shared.model_name = available_models[0] - # Select the model from a command-line menu elif shared.args.model_menu: if len(available_models) == 0: @@ -1147,30 +205,22 @@ if __name__ == "__main__": # If any model has been selected, load it if shared.model_name != 'None': - model_settings = get_model_settings_from_yamls(shared.model_name) - shared.settings.update(model_settings) # hijacking the interface defaults + p = Path(shared.model_name) + if p.exists(): + model_name = p.parts[-1] + shared.model_name = model_name + else: + model_name = shared.model_name + + model_settings = get_model_metadata(model_name) + shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments # Load the model - shared.model, shared.tokenizer = load_model(shared.model_name) + shared.model, shared.tokenizer = load_model(model_name) if shared.args.lora: add_lora_to_model(shared.args.lora) - # Forcing some events to be triggered on page load - shared.persistent_interface_state.update({ - 'loader': shared.args.loader or 'Transformers', - }) - - if shared.is_chat(): - shared.persistent_interface_state.update({ - 'mode': shared.settings['mode'], - 'character_menu': shared.args.character or shared.settings['character'], - 'instruction_template': shared.settings['instruction_template'] - }) - - if Path("cache/pfp_character.png").exists(): - Path("cache/pfp_character.png").unlink() - shared.generation_lock = Lock() # Launch the web UI diff --git a/settings-template.yaml b/settings-template.yaml index e949f69..0696f50 100644 --- a/settings-template.yaml +++ b/settings-template.yaml @@ -1,40 +1,34 @@ -dark_theme: false -autoload_model: true +dark_theme: true +show_controls: true +start_with: '' +mode: chat +chat_style: cai-chat +prompt-default: QA +prompt-notebook: QA +preset: simple-1 max_new_tokens: 200 max_new_tokens_min: 1 -max_new_tokens_max: 2000 +max_new_tokens_max: 4096 seed: -1 -character: None -name1: You -name2: Assistant -context: This is a conversation with your Assistant. It is a computer program designed - to help you with various tasks such as answering questions, providing recommendations, - and helping with decision making. You can ask it anything you want and it will do - its best to give you accurate and relevant information. -greeting: '' -turn_template: '' -custom_stopping_strings: '' -stop_at_newline: false -add_bos_token: true -ban_eos_token: false -skip_special_tokens: true +negative_prompt: '' truncation_length: 2048 truncation_length_min: 0 truncation_length_max: 16384 -mode: chat -start_with: '' -chat_style: cai-chat -instruction_template: None -chat-instruct_command: 'Continue the chat dialogue below. Write a single reply for - the character "<|character|>". +custom_stopping_strings: '' +auto_max_new_tokens: false +max_tokens_second: 0 +ban_eos_token: false +custom_token_bans: '' +add_bos_token: true +skip_special_tokens: true +stream: true +name1: You +character: Assistant +instruction_template: Alpaca +chat-instruct_command: |- + Continue the chat dialogue below. Write a single reply for the character "<|character|>". - - <|prompt|>' -chat_generation_attempts: 1 -chat_generation_attempts_min: 1 -chat_generation_attempts_max: 10 -default_extensions: [] -chat_default_extensions: + <|prompt|> +autoload_model: false +default_extensions: - gallery -preset: simple-1 -prompt: QA diff --git a/training/datasets/put-trainer-datasets-here.txt b/training/datasets/put-trainer-datasets-here.txt index e69de29..932eacf 100644 --- a/training/datasets/put-trainer-datasets-here.txt +++ b/training/datasets/put-trainer-datasets-here.txt @@ -0,0 +1 @@ +to load multiple raw text files create a subdirectory and put them all there