QJerry committed
Commit 6cc7643 · verified · 1 Parent(s): 60a9385

Update app.py

Files changed (1)
  1. app.py +37 -20
app.py CHANGED
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 import json
 import logging
 import os
@@ -5,13 +6,12 @@ import random
 import re
 import sys
 import warnings
-from dataclasses import dataclass
 
+from PIL import Image
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 import gradio as gr
 import spaces
 import torch
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
-from PIL import Image
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from prompt_check import is_unsafe_prompt
@@ -28,8 +28,10 @@ MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
 ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
 ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
 ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
+UNSAFE_MAX_NEW_TOKEN = os.environ.get("UNSAFE_MAX_NEW_TOKEN", 10)
 DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
 HF_TOKEN = os.environ.get("HF_TOKEN")
+UNSAFE_PROMPT_CHECK = os.environ.get("UNSAFE_PROMPT_CHECK")
 # =============================================================================
 
 
@@ -308,7 +310,7 @@ class APIPromptExpander(PromptExpander):
         if json_start != -1:
             json_end = content.find("```", json_start + 7)
             try:
-                json_str = content[json_start + 7: json_end].strip()
+                json_str = content[json_start + 7 : json_end].strip()
                 data = json.loads(json_str)
                 expanded_prompt = data.get("revised_prompt", content)
             except:
@@ -377,8 +379,15 @@ def prompt_enhance(prompt, enable_enhance):
 
 @spaces.GPU
 def generate(
-    prompt, resolution="1024x1024 ( 1:1 )", seed=42, steps=9, shift=3.0, random_seed=True, gallery_images=None,
-    enhance=False, progress=gr.Progress(track_tqdm=True)
+    prompt,
+    resolution="1024x1024 ( 1:1 )",
+    seed=42,
+    steps=9,
+    shift=3.0,
+    random_seed=True,
+    gallery_images=None,
+    enhance=False,
+    progress=gr.Progress(track_tqdm=True),
 ):
     """
     Generate an image using the Z-Image model based on the provided prompt and settings.
@@ -417,9 +426,15 @@ def generate(
     if pipe is None:
         raise gr.Error("Model not loaded.")
 
-    has_nsfw_concept = is_unsafe_prompt(pipe.text_encoder, pipe.tokenizer, prompt)
-    if has_nsfw_concept:
-        raise UnsafeContentError("input unsafe")
+    has_unsafe_concept = is_unsafe_prompt(
+        pipe.text_encoder,
+        pipe.tokenizer,
+        system_prompt=UNSAFE_PROMPT_CHECK,
+        user_prompt=prompt,
+        max_new_token=UNSAFE_MAX_NEW_TOKEN,
+    )
+    if has_unsafe_concept:
+        raise UnsafeContentError("Input unsafe")
 
     final_prompt = prompt
 
@@ -443,9 +458,7 @@ def generate(
     )
 
     safety_checker_input = pipe.safety_feature_extractor([image], return_tensors="pt").pixel_values.cuda()
-    _, has_nsfw_concept = pipe.safety_checker(
-        images=[torch.zeros(1)], clip_input=safety_checker_input
-    )
+    _, has_nsfw_concept = pipe.safety_checker(images=[torch.zeros(1)], clip_input=safety_checker_input)
     has_nsfw_concept = has_nsfw_concept[0]
     if has_nsfw_concept:
         raise UnsafeContentError("input unsafe")
@@ -493,8 +506,9 @@ with gr.Blocks(title="Z-Image Demo") as demo:
             res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
 
             initial_res_choices = RES_CHOICES["1024"]
-            resolution = gr.Dropdown(value=initial_res_choices[0], choices=RESOLUTION_SET,
-                                     label="Width x Height (Ratio)")
+            resolution = gr.Dropdown(
+                value=initial_res_choices[0], choices=RESOLUTION_SET, label="Width x Height (Ratio)"
+            )
 
             with gr.Row():
                 seed = gr.Number(label="Seed", value=42, precision=0)
@@ -512,12 +526,16 @@ with gr.Blocks(title="Z-Image Demo") as demo:
 
         with gr.Column(scale=1):
             output_gallery = gr.Gallery(
-                label="Generated Images", columns=2, rows=2, height=600, object_fit="contain", format="png",
-                interactive=False
+                label="Generated Images",
+                columns=2,
+                rows=2,
+                height=600,
+                object_fit="contain",
+                format="png",
+                interactive=False,
             )
             used_seed = gr.Textbox(label="Seed Used", interactive=False)
 
-
     def update_res_choices(_res_cat):
         if str(_res_cat) in RES_CHOICES:
             res_choices = RES_CHOICES[str(_res_cat)]
@@ -525,7 +543,6 @@ with gr.Blocks(title="Z-Image Demo") as demo:
             res_choices = RES_CHOICES["1024"]
         return gr.update(value=res_choices[0], choices=res_choices)
 
-
    res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution, api_visibility="private")
 
     # PE enhancement button (Temporarily disabled)
@@ -542,8 +559,8 @@ with gr.Blocks(title="Z-Image Demo") as demo:
         api_visibility="public",
     )
 
-css = '''
+css = """
 .fillable{max-width: 1230px !important}
-'''
+"""
 if __name__ == "__main__":
     demo.launch(css=css, mcp_server=True)
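
The updated call site in generate() now passes system_prompt, user_prompt, and max_new_token keyword arguments to prompt_check.is_unsafe_prompt, with both new values sourced from the UNSAFE_PROMPT_CHECK and UNSAFE_MAX_NEW_TOKEN environment variables. A minimal sketch of a signature compatible with that call, assuming the text encoder is a causal LLM queried with a yes/no safety instruction (illustrative only, not the actual prompt_check.py in this repo):

import torch

def is_unsafe_prompt(text_encoder, tokenizer, system_prompt=None, user_prompt="", max_new_token=10):
    """Hypothetical sketch: ask the (LLM) text encoder whether the user prompt is unsafe."""
    messages = [
        {"role": "system", "content": system_prompt or "Reply 'unsafe' or 'safe' for the user prompt."},
        {"role": "user", "content": user_prompt},
    ]
    # Chat-style templating; assumes the tokenizer exposes apply_chat_template.
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    input_ids = input_ids.to(text_encoder.device)
    with torch.no_grad():
        # UNSAFE_MAX_NEW_TOKEN is read with os.environ.get, which returns a string
        # when the variable is set, so the cast to int here is defensive.
        output = text_encoder.generate(input_ids, max_new_tokens=int(max_new_token))
    reply = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)
    return "unsafe" in reply.lower()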
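
Because the generate event is registered with api_visibility="public" and the app launches with mcp_server=True, the endpoint can also be driven programmatically. A hedged usage sketch with gradio_client, where the Space id placeholder and the "/generate" api_name are assumptions rather than values taken from this commit:

from gradio_client import Client

client = Client("<owner>/<space-id>")  # replace with the actual Z-Image demo Space
result = client.predict(
    prompt="A watercolor fox in a misty forest",
    resolution="1024x1024 ( 1:1 )",
    seed=42,
    steps=9,
    shift=3.0,
    random_seed=True,
    gallery_images=None,
    enhance=False,
    api_name="/generate",
)
print(result)  # typically the gallery output plus the "Seed Used" value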