logo
4

Stable Diffusion中的插件功能实现案例

起因

之前用gradio搭一些项目的时候一般只有1~2个tab页,官方的技术文档已能满足需求,因此对组件的拓展性没有太关注。最近在研究Stable Diffusion,发现绝大部分web ui的实现都非常全面,一般都有十几个tab页,而且把鼠标移动到组件上,会发现有文字提示,如下图所示,这肯定不是原生gradio的组件。这引起了我的好奇心,gradio应该是比较简单,主要用于demo展示的,那这些花里胡哨的组件是怎么实现的?

探究

通过搜索相关的资料,发现大多插件都是以github仓库的形式存在的,那这就好办了,通过一个简单的项目即可窥斑见豹。在这里我找到一个插件:github.com/butaixianra… 这个是一个实现prompt翻译的插件,目录结构如下:
  
  
  
  
  
  
--img # demo用到的图片
|-- xxx.jpg
--javascript # js函数定义
|-- prompt_translator.js
--scripts
|-- services # 谷歌翻译相关的模块,这里以百度翻译为例,因此无需细究
|-- __init__.py
|-- google.py
|-- interface.py
|-- schema.py
|-- lang_code.py # 不同的翻译api中,语言和对应缩写。如:中文 对应 zh
|-- prompt_translator.py # 核心功能模块
--.gitignore
--LICENSE
--README.cn.md
--README.md
各个模块的功能如上所述,这里只摘取prompt_translator.py进行说明

分析

以下分析针对prompt_translator.py中的关键变量、函数进行说明。

1 从原始web-ui模块中引入的4个组件


这4个组件不是在这个项目中搭建的,而是从原始web-ui项目中引入(几乎所有开源的sd-web-ui实现都是在原始项目的基础上开发的),而这4个组件则代表最原始的功能,即文生图的prompt、文生图的负向prompt、图生图的prompt、图生图的负向prompt。当然原始项目还有很多其他组件,主要是因为这个插件是实现翻译,因此只需要引入这几个即可。
  
  
  
  
  
  
# 从原始web-ui项目中引入的4个组件
txt2img_prompt = modules.ui.txt2img_paste_fields[0][0]
txt2img_neg_prompt = modules.ui.txt2img_paste_fields[1][0]
img2img_prompt = modules.ui.img2img_paste_fields[0][0]
img2img_neg_prompt = modules.ui.img2img_paste_fields[1][0]

2 隐藏的按钮


trans_prompt_js_btntrans_neg_prompt_js_btn这两个按钮隐藏起来了,可以看到其初始化的属性:visible=False。因此对于该项目搭建的gradio app来说,这两个按钮是没用的,对应的绑定的两个监听函数也是没用的。 用更简单的话说,如果用该项目启动gradio app,不会出现这两个按钮。如单独部署的界面。
  
  
  
  
  
  
# gradio搭建
with gr.Blocks(analytics_enabled=False) as prompt_translator:
# ====ui====
# Prompt Area
with gr.Row():
tar_lang_drop = gr.Dropdown(label="Target Language", choices=tar_langs, value=def_tar_lang, elem_id="pt_tar_lang")
with gr.Row():
prompt = gr.Textbox(label="Prompt", lines=3, value="", elem_id="pt_prompt")
translated_prompt = gr.Textbox(label="Translated Prompt", lines=3, value="", elem_id="pt_translated_prompt")
with gr.Row():
trans_prompt_btn = gr.Button(value="Translate", elem_id="pt_trans_prompt_btn")
# add a hidden button, used by fake click with javascript. To simulate msg between server and client side.
# this is the only way.
# 隐藏的按钮,用于js触发
trans_prompt_js_btn = gr.Button(value="Trans Js", visible=False, elem_id="pt_trans_prompt_js_btn")
send_prompt_btn = gr.Button(value="Send to txt2img and img2img", elem_id="pt_send_prompt_btn")
with gr.Row():
neg_prompt = gr.Textbox(label="Negative Prompt", lines=2, value="", elem_id="pt_neg_prompt")
translated_neg_prompt = gr.Textbox(label="Translated Negative Prompt", lines=2, value=json.dumps(tar_langs), elem_id="pt_translated_neg_prompt")
with gr.Row():
trans_neg_prompt_btn = gr.Button(value="Translate", elem_id="pt_trans_neg_prompt_btn")
# add a hidden button, used by fake click with javascript. To simulate msg between server and client side.
# this is the only way.
# 隐藏的按钮,用于js触发
trans_neg_prompt_js_btn = gr.Button(value="Trans Js", visible=False, elem_id="pt_trans_neg_prompt_js_btn")
send_neg_prompt_btn = gr.Button(value="Send to txt2img and img2img", elem_id="pt_send_neg_prompt_btn")

3 交互


上述代码再往后看,会发现该模块最后有这么一行代码:
  
  
  
  
  
  
script_callbacks.on_ui_tabs(on_ui_tabs)
经过对比,发现这是用了原始web-ui项目中提供的一个回调方法,实现插件功能,即把这个插件页面补充到原始web-ui项目中去。

4 原始项目中的方法


查看回原始项目(github.com/AUTOMATIC11…),发现2个模块比较关键。 对于modules/script_callbacks.py,定义了一些把需要新增的插件补充进原始项目中去的方法,这里仅列出2个跟本翻译插件项目关联最紧密的方法说明。
  
  
  
  
  
  
# file: modules/script_callbacks.py
# 把当前的插件添加进callback_map字典中
def on_ui_tabs(callback):
"""register a function to be called when the UI is creating new tabs.
The function must either return a None, which means no new tabs to be added, or a list, where
each element is a tuple:
(gradio_component, title, elem_id)
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
add_callback(callback_map['callbacks_ui_tabs'], callback)
# 遍历callback_map字典,把新增的插件都添加进来
def ui_tabs_callback():
res = []
for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
report_exception(c, 'ui_tabs_callback')
return res
对于modules/ui.py,即实现了原始项目的核心gradio界面功能,通常初始化6个tab页面,然后通过上述回调方法,把各种需要的插件功能添加进原始项目中去。
  
  
  
  
  
  
# file: modules/ui.py
# ...
# 前面部分与通常的gradio搭建大同小异
# 初始的6个tab页
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "train"),
]
# 通过ui_tabs_callback方法实现tab的拓展,这些拓展的tab即插件,如本文说的翻译插件
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings.interface, "Settings", "settings")]

总结

通过一个翻译插件项目,我了解到怎么在原始的的gradio项目中添加额外的插件功能,而无需进行太多改动。而且这些插件功能通过js可以实现各种额外的功能,大大弥补了gradio的弱点。
版权声明:本文为稀土掘金博主「深度学习机器」的原创文章
如有侵权,请联系千帆社区进行删除
评论
用户头像