Fucius committed
Commit 7a6033e
Parent: 78ed006

Update app.py

Files changed (1):
  app.py  +107 -107
app.py CHANGED
@@ -368,75 +368,103 @@ def main(device, segment_type):

    @spaces.GPU
    def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style):
-       try:
-           path1 = lorapath_man[man]
-           path2 = lorapath_woman[woman]
-           pipe_concept.unload_lora_weights()
-           pipe.unload_lora_weights()
-           pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args, pipe)
-
-           if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
-               styleL = True
-           else:
-               styleL = False
-
-           input_list = [prompt1]
-           condition_list = [condition_img1]
-           output_list = []
-
-           width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
-
-           kwargs = {
-               'height': height,
-               'width': width,
-           }
-
-           for prompt, condition_img in zip(input_list, condition_list):
-               if prompt!='':
-                   input_prompt = []
-                   p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
-                   if styleL:
-                       p = styles[style] + p
-                   input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
-                   if styleL:
-                       input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]),
-                                            (styles[style] + local_prompt2, character_woman.get(woman)[1])])
-                   else:
-                       input_prompt.append([(local_prompt1, character_man.get(man)[1]),
-                                            (local_prompt2, character_woman.get(woman)[1])])
-
-                   if condition == 'Human pose' and condition_img is not None:
-                       index = ratio_list.index(
-                           min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
-                       resolution = resolution_list[index]
-                       width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
-                       kwargs['height'] = height
-                       kwargs['width'] = width
-                       condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
-                       spatial_condition = get_humanpose(condition_img)
-                   elif condition == 'Canny Edge' and condition_img is not None:
-                       index = ratio_list.index(
-                           min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
-                       resolution = resolution_list[index]
-                       width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
-                       kwargs['height'] = height
-                       kwargs['width'] = width
-                       condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
-                       spatial_condition = get_cannyedge(condition_img)
-                   elif condition == 'Depth' and condition_img is not None:
-                       index = ratio_list.index(
-                           min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
-                       resolution = resolution_list[index]
-                       width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
-                       kwargs['height'] = height
-                       kwargs['width'] = width
-                       condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
-                       spatial_condition = get_depth(condition_img)
-                   else:
-                       spatial_condition = None
-
-                   kwargs['spatial_condition'] = spatial_condition
-                   controller.reset()
                    image = sample_image(
                        pipe,
                        input_prompt=input_prompt,
@@ -444,47 +472,19 @@ def main(device, segment_type):
                        input_neg_prompt=[negative_prompt] * len(input_prompt),
                        generator=torch.Generator(device).manual_seed(seed),
                        controller=controller,
-                       stage=1,
                        lora_list=pipe_list,
                        styleL=styleL,
                        **kwargs)
-
-                   controller.reset()
-                   if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
-                       mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
-                                            threshold=0.5)
-                   else:
-                       mask1 = None
-
-                   if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
-                       mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
-                                            threshold=0.5)
-                   else:
-                       mask2 = None
-
-                   if mask1 is None and mask2 is None:
-                       output_list.append(image[1])
-                   else:
-                       image = sample_image(
-                           pipe,
-                           input_prompt=input_prompt,
-                           concept_models=pipe_concept,
-                           input_neg_prompt=[negative_prompt] * len(input_prompt),
-                           generator=torch.Generator(device).manual_seed(seed),
-                           controller=controller,
-                           stage=2,
-                           region_masks=[mask1, mask2],
-                           lora_list=pipe_list,
-                           styleL=styleL,
-                           **kwargs)
-                       output_list.append(image[1])
-               else:
-                   output_list.append(None)
-               output_list.append(spatial_condition)
-               return output_list
-       except:
-           print("error")
-           return

    def get_local_value_man(input):
        return character_man[input][0]
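
An aside on the mask gating, which is unchanged in substance between the two versions: pipe.tokenizer("man")["input_ids"][1] picks the first real token id of "man" (index 0 is the BOS token in CLIP-style tokenizers), and the in test checks whether that id occurs in the prompt's token ids with BOS and EOS stripped ([1:-1]); predict_mask only runs when the word actually appears in the prompt. A minimal, self-contained sketch of the same check, assuming a standalone CLIP tokenizer from transformers behaves like pipe.tokenizer in app.py:

from transformers import CLIPTokenizer

# Assumption: any CLIP-style tokenizer works here; app.py uses the pipeline's own tokenizer.
tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def prompt_mentions(word: str, prompt: str) -> bool:
    # [1] skips the BOS token; [1:-1] strips BOS and EOS from the prompt's token ids.
    word_id = tok(word)["input_ids"][1]
    return word_id in tok(prompt)["input_ids"][1:-1]

print(prompt_mentions("man", "a man and a woman in a park"))  # True
print(prompt_mentions("man", "two dogs on a beach"))          # False

Note the comparison is per token id, so it only works cleanly for words that CLIP encodes as a single token, as "man" and "woman" are.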
 

    @spaces.GPU
    def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style):
+       # try:
+       path1 = lorapath_man[man]
+       path2 = lorapath_woman[woman]
+       pipe_concept.unload_lora_weights()
+       pipe.unload_lora_weights()
+       pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args, pipe)
+
+       if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
+           styleL = True
+       else:
+           styleL = False
+
+       input_list = [prompt1]
+       condition_list = [condition_img1]
+       output_list = []
+
+       width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+
+       kwargs = {
+           'height': height,
+           'width': width,
+       }
+
+       for prompt, condition_img in zip(input_list, condition_list):
+           if prompt!='':
+               input_prompt = []
+               p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
+               if styleL:
+                   p = styles[style] + p
+               input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
+               if styleL:
+                   input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]),
+                                        (styles[style] + local_prompt2, character_woman.get(woman)[1])])
+               else:
+                   input_prompt.append([(local_prompt1, character_man.get(man)[1]),
+                                        (local_prompt2, character_woman.get(woman)[1])])
+
+               if condition == 'Human pose' and condition_img is not None:
+                   index = ratio_list.index(
+                       min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                   resolution = resolution_list[index]
+                   width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                   kwargs['height'] = height
+                   kwargs['width'] = width
+                   condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                   spatial_condition = get_humanpose(condition_img)
+               elif condition == 'Canny Edge' and condition_img is not None:
+                   index = ratio_list.index(
+                       min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                   resolution = resolution_list[index]
+                   width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                   kwargs['height'] = height
+                   kwargs['width'] = width
+                   condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                   spatial_condition = get_cannyedge(condition_img)
+               elif condition == 'Depth' and condition_img is not None:
+                   index = ratio_list.index(
+                       min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                   resolution = resolution_list[index]
+                   width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                   kwargs['height'] = height
+                   kwargs['width'] = width
+                   condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                   spatial_condition = get_depth(condition_img)
+               else:
+                   spatial_condition = None
+
+               kwargs['spatial_condition'] = spatial_condition
+               controller.reset()
+               image = sample_image(
+                   pipe,
+                   input_prompt=input_prompt,
+                   concept_models=pipe_concept,
+                   input_neg_prompt=[negative_prompt] * len(input_prompt),
+                   generator=torch.Generator(device).manual_seed(seed),
+                   controller=controller,
+                   stage=1,
+                   lora_list=pipe_list,
+                   styleL=styleL,
+                   **kwargs)
+
+               controller.reset()
+               if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                   mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
+                                        threshold=0.5)
+               else:
+                   mask1 = None
+
+               if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                   mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
+                                        threshold=0.5)
+               else:
+                   mask2 = None
+
+               if mask1 is None and mask2 is None:
+                   output_list.append(image[1])
+               else:
                    image = sample_image(
                        pipe,
                        input_prompt=input_prompt,

                        input_neg_prompt=[negative_prompt] * len(input_prompt),
                        generator=torch.Generator(device).manual_seed(seed),
                        controller=controller,
+                       stage=2,
+                       region_masks=[mask1, mask2],
                        lora_list=pipe_list,
                        styleL=styleL,
                        **kwargs)
+                   output_list.append(image[1])
+           else:
+               output_list.append(None)
+           output_list.append(spatial_condition)
+           return output_list
+       # except:
+       # print("error")
+       # return

    def get_local_value_man(input):
        return character_man[input][0]
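
Two closing observations on the new version. Commenting out the try/except removes a bare except: that printed "error" and returned None, which swallowed every traceback; with it gone, failures surface directly in the Space logs. Separately, the aspect-ratio matching, resolution update, and center crop are repeated verbatim in the 'Human pose', 'Canny Edge', and 'Depth' branches. A minimal sketch of one way to factor them out; prepare_spatial_condition and the extractors table are hypothetical names, while ratio_list, resolution_list, resize_and_center_crop, Image, get_humanpose, get_cannyedge, and get_depth are the module-level names used above:

from PIL import Image

def prepare_spatial_condition(condition, condition_img, kwargs):
    # Hypothetical helper for app.py; relies on the module-level names shown above.
    extractors = {
        'Human pose': get_humanpose,
        'Canny Edge': get_cannyedge,
        'Depth': get_depth,
    }
    if condition_img is None or condition not in extractors:
        return None
    # condition_img is an H x W x C numpy array, so shape[1] / shape[0] is W / H.
    aspect = condition_img.shape[1] / condition_img.shape[0]
    # Pick the supported resolution whose aspect ratio is closest to the input's.
    index = ratio_list.index(min(ratio_list, key=lambda x: abs(x - aspect)))
    width, height = (int(v) for v in resolution_list[index].split("*"))
    # Mirror the original's in-place update of the sampling kwargs.
    kwargs['height'], kwargs['width'] = height, width
    resized = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
    return extractors[condition](resized)

The whole if/elif/else chain in generate_image would then reduce to spatial_condition = prepare_spatial_condition(condition, condition_img, kwargs).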