
Commit c3324d7

abidlabs and gradio-pr-bot authored Nov 20, 2024
Fix issues related to examples and example caching in gr.ChatInterface (#9990)
* changes
* changes
* add functional tests
* add changeset
* revert
* example format
* chat interface
* replace attribute with str
* replace attribute with function
* fix tests
* changes
* fix
* more changes
* changes
* changes
* demo
* more changes
* typing
* demos
* test
* changes
* changes
* functional tests
* add changeset
* fix pytest

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent e7629f7 commit c3324d7

20 files changed, +216 -130 lines
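For orientation: the commit centers on how `gr.ChatInterface` normalizes its `examples` parameter (bare strings, single-item lists, and multimodal dicts) and on caching example outputs. A minimal sketch of the surface being fixed, assuming the current `gr.ChatInterface` API (the `echo` chat function is hypothetical):

import gradio as gr

def echo(message, history):
    # Hypothetical chat function: return the user's message unchanged.
    return message

demo = gr.ChatInterface(
    echo,
    type="messages",                      # openai-style {"role", "content"} history
    examples=["Hey", "What is Python?"],  # example prompts shown under the chatbot
    cache_examples=True,                  # pre-compute and replay example outputs
)

if __name__ == "__main__":
    demo.launch()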
 

.changeset/brown-hounds-crash.md (+5)

@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+fix:Fix issues related to examples and example caching in `gr.ChatInterface`

demo/agent_chatbot/utils.py (+1)

@@ -1,3 +1,4 @@
+# type: ignore
 from __future__ import annotations
 
 from gradio import ChatMessage

demo/bokeh_plot/run.ipynb (+1, -1)

@@ -1 +1 @@
-{"cells": [{"cell_type": "markdown", ..., "source": ["# Gradio Demo: bokeh_plot"]}, {"cell_type": "code", ..., "source": ["!pip install -q gradio bokeh>=3.0 xyzservices"]}, {"cell_type": "code", ..., "source": ["import gradio as gr\n", "import xyzservices.providers as xyz\n", ...]}], ...}
+{"cells": [{"cell_type": "markdown", ..., "source": ["# Gradio Demo: bokeh_plot"]}, {"cell_type": "code", ..., "source": ["!pip install -q gradio bokeh>=3.0 xyzservices"]}, {"cell_type": "code", ..., "source": ["# type: ignore\n", "import gradio as gr\n", "import xyzservices.providers as xyz\n", ...]}], ...}

(Single-line notebook JSON, shown condensed: the only change is the "# type: ignore\n" entry prepended to the main code cell's source, mirroring the run.py change below.)

demo/bokeh_plot/run.py (+1)

@@ -1,3 +1,4 @@
+# type: ignore
 import gradio as gr
 import xyzservices.providers as xyz
 from bokeh.models import ColumnDataSource, Whisker

demo/chatbot_core_components/run.ipynb (+1, -1)

(Large diff not rendered by default.)

demo/chatbot_core_components/run.py (+1)

@@ -1,3 +1,4 @@
+# type: ignore
 import gradio as gr
 import os
 import plotly.express as px
demo/rt-detr-object-detection/run.ipynb (+1, -1)

@@ -1 +1 @@
-{"cells": [{"cell_type": "markdown", ..., "source": ["# Gradio Demo: rt-detr-object-detection"]}, ..., {"cell_type": "code", ..., "source": ["import spaces\n", "import gradio as gr\n", "import cv2\n", ...]}], ...}
+{"cells": [{"cell_type": "markdown", ..., "source": ["# Gradio Demo: rt-detr-object-detection"]}, ..., {"cell_type": "code", ..., "source": ["# type: ignore\n", "import spaces\n", "import gradio as gr\n", "import cv2\n", ...]}], ...}

(Single-line notebook JSON, shown condensed: the only change is the "# type: ignore\n" entry prepended to the main code cell's source, mirroring the run.py change below.)

demo/rt-detr-object-detection/run.py (+1)

@@ -1,3 +1,4 @@
+# type: ignore
 import spaces
 import gradio as gr
 import cv2
demo/test_chatinterface_examples/eager_caching_examples_testcase.py (+27)

@@ -0,0 +1,27 @@
+import gradio as gr
+
+def generate(
+    message: str,
+    chat_history: list[dict],
+):
+
+    output = ""
+    for character in message:
+        output += character
+        yield output
+
+
+demo = gr.ChatInterface(
+    fn=generate,
+    examples=[
+        ["Hey"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+    ],
+    cache_examples=True,
+    cache_mode="eager",
+    type="messages",
+)
+
+
+if __name__ == "__main__":
+    demo.launch()
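This new testcase pairs `cache_examples=True` with `cache_mode="eager"`: each example is run through the chat function when the demo starts and its streamed output is stored, so clicking an example replays the cached response. The alternative `"lazy"` mode defers caching an example until its first use; that behavior is assumed from the current Gradio documentation rather than from this diff.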

demo/test_chatinterface_examples/multimodal_messages_examples_testcase.py (+1, -1)

@@ -24,4 +24,4 @@ def generate(
 
 
 if __name__ == "__main__":
-    demo.launch()
\ No newline at end of file
+    demo.launch()

demo/test_chatinterface_examples/multimodal_tuples_examples_testcase.py (+1, -1)

@@ -24,4 +24,4 @@ def generate(
 
 
 if __name__ == "__main__":
-    demo.launch()
\ No newline at end of file
+    demo.launch()
demo/test_chatinterface_examples/run.ipynb (+1, -1)

@@ -1 +1 @@
-{"cells": [..., {"cell_type": "code", ..., "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_messages_examples_testcase.py\n", ...]}, ..., " demo.launch()"]}], ...}
+{"cells": [..., {"cell_type": "code", ..., "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/eager_caching_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_messages_examples_testcase.py\n", ...]}, ..., " demo.launch()\n"]}], ...}

(Single-line notebook JSON, shown condensed: the changes add a download line for the new eager_caching_examples_testcase.py and a trailing newline after the final demo.launch().)

demo/test_chatinterface_examples/run.py (+1, -1)

@@ -23,4 +23,4 @@ def generate(
 
 
 if __name__ == "__main__":
-    demo.launch()
\ No newline at end of file
+    demo.launch()

demo/test_chatinterface_examples/tuples_examples_testcase.py (+1, -1)

@@ -24,4 +24,4 @@ def generate(
 
 
 if __name__ == "__main__":
-    demo.launch()
\ No newline at end of file
+    demo.launch()

gradio/chat_interface.py (+112, -115)

(Large diff not rendered by default.)

gradio/components/chatbot.py (+3, -3)

@@ -108,7 +108,7 @@ class ChatbotDataMessages(GradioRootModel):
     root: list[Message]
 
 
-TupleFormat = list[
+TupleFormat = Sequence[
     tuple[Union[str, tuple[str], None], Union[str, tuple[str], None]]
     | list[Union[str, tuple[str], None]]
 ]
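The switch from `list[...]` to `Sequence[...]` loosens what `TupleFormat` accepts: `Sequence` is covariant and read-only, so tuples of chat pairs now type-check alongside lists. A minimal sketch of the distinction (function names are hypothetical, not part of this diff):

from typing import Sequence

def takes_list(history: list[tuple[str, str]]) -> None: ...
def takes_sequence(history: Sequence[tuple[str, str]]) -> None: ...

pairs = (("hi", "hello"),)   # a tuple of (user, assistant) pairs
takes_sequence(pairs)        # fine: tuples are Sequences
takes_list(pairs)            # flagged by a type checker: a tuple is not a list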
@@ -227,7 +227,7 @@ def __init__(
         """
         if type is None:
             warnings.warn(
-                "You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style 'role' and 'content' keys.",
+                "You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.",
                 UserWarning,
             )
             type = "tuples"
@@ -315,7 +315,7 @@ def __init__(
                 file_info[i] = file_data
 
     @staticmethod
-    def _check_format(messages: list[Any], type: Literal["messages", "tuples"]):
+    def _check_format(messages: Any, type: Literal["messages", "tuples"]):
         if type == "messages":
             all_valid = all(
                 isinstance(message, dict)
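For reference, the 'messages' format that the reworded warning and the relaxed `_check_format` signature both refer to is a list of openai-style dicts; a minimal sketch:

history = [
    {"role": "user", "content": "Hey"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]

`_check_format` treats the history as valid when every entry is such a dict (or an equivalent message object) and rejects it otherwise.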

gradio/processing_utils.py (+1, -1)

@@ -76,7 +76,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
         return httpx.Response(
             status_code=response.status,
             headers=response_headers,
-            stream=Urllib3ResponseSyncByteStream(response),
+            stream=Urllib3ResponseSyncByteStream(response),  # type: ignore
         )
 
 sync_transport = Urllib3Transport()

js/spa/test/test_chatinterface_examples.spec.ts (+3, -2)

@@ -4,11 +4,12 @@ const cases = [
 	"messages",
 	"tuples_examples",
 	"multimodal_tuples_examples",
-	"multimodal_messages_examples"
+	"multimodal_messages_examples",
+	"eager_caching_examples"
 ];
 
 for (const test_case of cases) {
-	test(`test case ${test_case} clicking example properly adds it to the history and passes the correct values to the prediction function`, async ({
+	test(`case ${test_case}: clicked example is added to history and passed to chat function`, async ({
 		page
 	}) => {
 		if (cases.slice(1).includes(test_case)) {

test/components/test_image.py (+1, -1)

@@ -7,7 +7,7 @@
 from gradio_client import utils as client_utils
 
 import gradio as gr
-from gradio.components.image import ImageData
+from gradio.components.image import ImageData  # type: ignore
 from gradio.exceptions import Error
 

test/test_chat_interface.py (+52)

@@ -305,3 +305,55 @@ def double_multimodal(msg, history):
     with connect(chatbot) as client:
         result = client.predict({"text": "hello", "files": []}, api_name="/chat")
         assert result == "hello hello"
+
+
+class TestExampleMessages:
+    def test_setup_example_messages_with_strings(self):
+        chat = gr.ChatInterface(
+            double,
+            examples=["hello", "hi", "hey"],
+            example_labels=["Greeting 1", "Greeting 2", "Greeting 3"],
+        )
+        assert len(chat.examples_messages) == 3
+        assert chat.examples_messages[0] == {
+            "text": "hello",
+            "display_text": "Greeting 1",
+        }
+        assert chat.examples_messages[1] == {
+            "text": "hi",
+            "display_text": "Greeting 2",
+        }
+        assert chat.examples_messages[2] == {
+            "text": "hey",
+            "display_text": "Greeting 3",
+        }
+
+    def test_setup_example_messages_with_multimodal(self):
+        chat = gr.ChatInterface(
+            double,
+            examples=[
+                {"text": "hello", "files": ["file1.txt"]},
+                {"text": "hi", "files": ["file2.txt", "file3.txt"]},
+                {"text": "", "files": ["file4.txt"]},
+            ],
+        )
+        assert len(chat.examples_messages) == 3
+        assert chat.examples_messages[0]["text"] == "hello"  # type: ignore
+        assert chat.examples_messages[0]["files"][0]["path"].endswith("file1.txt")  # type: ignore
+
+    def test_setup_example_messages_with_lists(self):
+        chat = gr.ChatInterface(
+            double,
+            examples=[
+                ["hello", "other_value"],
+                ["hi", "another_value"],
+            ],
+        )
+        assert len(chat.examples_messages) == 2
+        assert chat.examples_messages[0] == {"text": "hello"}
+        assert chat.examples_messages[1] == {"text": "hi"}
+
+    def test_setup_example_messages_empty(self):
+        chat = gr.ChatInterface(double)
+        chat._setup_example_messages(None)
+        assert chat.examples_messages == []
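To run just this new class locally, something like `pytest test/test_chat_interface.py -k TestExampleMessages` should work, assuming the repository's usual pytest setup.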
