使用 Gradio 进行演示

使用 Gradio 进行演示#

这个笔记本通过一个示例介绍了 Gradio 库,该库能够帮助您轻松创建演示界面

对象检测界面#

在这个示例中,我们使用一个图像对象检测模型,该模型在包含 80 个类别的 COCO 数据集上进行了训练。 具体采用的是 Meta 开发的 DETR 模型(facebook/detr-resnet-50)。

实现过程#

首先,我们使用 Hugging Face 的 transformers 库来构建处理流程

from PIL import Image
import matplotlib.pyplot as plt
from transformers import pipeline
import cv2
import numpy as np
/home/aquilae/anaconda3/envs/dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
detector = pipeline("object-detection", "facebook/detr-resnet-50")
Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
image = Image.open('images/coco3.jpg')
plt.imshow(image)
plt.axis('off')
plt.show()
../_images/f8ca149eec73faa44fab501414ad0c54b0f117ac45531e14e2c21ed30faa5b64.png

接下来进行对象检测并绘制边界框。为了保持结果清晰,我们仅标记图中的人。如果需要检测其他类别,可以移除此过滤条件。

def draw_boxes(image,output):
  cv_image = np.array(image)
  for bbox in output:
    box = bbox['box']
    label = bbox['label']
    if (label!="person"):
      continue
    cv2.rectangle(cv_image, (box['xmin'], box['ymin']), (box['xmax'], box['ymax']), (0, 0, 255), 1)
    cv2.putText(cv_image, label, (box['xmin'], box['ymin'] - 10), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
  return cv_image

output = detector(image)
cv_image=draw_boxes(image,output)
plt.imshow(cv_image)
plt.axis('off')
plt.show()
../_images/d8ed6e499f6b26631be8d7c3db3a085bc0e0065870d9824c9b4b5713cb68ae4f.png

Gradio 演示#

现在,我们将用 Gradio 构建一个演示界面。该界面的功能是:

  • 输入:一张图片

  • 输出:在同一张图片上标记出所有人的位置。

import gradio as gr
def get_pipeline_prediction(pil_image):
  pipeline_output = detector(pil_image)
  processed_image = draw_boxes(pil_image,pipeline_output)
  return processed_image
demo = gr.Interface(
  fn=get_pipeline_prediction,
  inputs=gr.Image(label="Image d'entrée",type="pil"),
  outputs=gr.Image(label="Image avec les personnes détectées",type="pil")
)
IMPORTANT: You are using gradio version 4.24.0, however version 4.29.0 is available, please upgrade.
--------
demo.launch()
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.

这样,您的自定义演示界面就完成了! 您可以通过在 launch() 方法中设置 share=True 参数来共享这个演示。但请注意:

  • 共享期间需保持笔记本处于活动状态,否则演示将消失。

  • 如果希望无需保持电脑开机也能运行演示,可在 Hugging Face 网站上创建一个 Space(详见笔记本 1)。

注意:Gradio 的功能远不止于此。 如果您有更多定制需求,可查阅: