@@ -130,10 +130,10 @@ def main(args):
130
130
input_name = model .graph .input [0 ].name
131
131
try :
132
132
batch_size = model .graph .input [0 ].type .tensor_type .shape .dim [0 ].dim_value
133
- except :
133
+ except Exception :
134
134
batch_size = model .graph .input [0 ].type .tensor_type .shape .dim [0 ].dim_param
135
135
136
- if args .batch_size != batch_size and type (batch_size ) != str and batch_size > 0 :
136
+ if args .batch_size != batch_size and not isinstance (batch_size , str ) and batch_size > 0 :
137
137
# some rebatch model may use 'N' as dynamic axes
138
138
logger .warning (
139
139
f"Batch size { args .batch_size } doesn't match onnx model batch size { batch_size } , use model batch size { batch_size } "
@@ -169,13 +169,14 @@ def main(args):
169
169
170
170
with open (os .path .join (model_location , CSV_FILE ), "r" , encoding = "utf-8" ) as f :
171
171
reader = csv .reader (f )
172
- l = [row for row in reader ]
173
- header = l [0 ] # tag_id,name,category,count
174
- rows = l [1 :]
172
+ line = [row for row in reader ]
173
+ header = line [0 ] # tag_id,name,category,count
174
+ rows = line [1 :]
175
175
assert header [0 ] == "tag_id" and header [1 ] == "name" and header [2 ] == "category" , f"unexpected csv format: { header } "
176
176
177
- general_tags = [row [1 ] for row in rows [1 :] if row [2 ] == "0" ]
178
- character_tags = [row [1 ] for row in rows [1 :] if row [2 ] == "4" ]
177
+ rating_tags = [row [1 ] for row in rows [0 :] if row [2 ] == "9" ]
178
+ general_tags = [row [1 ] for row in rows [0 :] if row [2 ] == "0" ]
179
+ character_tags = [row [1 ] for row in rows [0 :] if row [2 ] == "4" ]
179
180
180
181
# 画像を読み込む
181
182
@@ -202,17 +203,13 @@ def run_batch(path_imgs):
202
203
probs = probs .numpy ()
203
204
204
205
for (image_path , _ ), prob in zip (path_imgs , probs ):
205
- # 最初の4つはratingなので無視する
206
- # # First 4 labels are actually ratings: pick one with argmax
207
- # ratings_names = label_names[:4]
208
- # rating_index = ratings_names["probs"].argmax()
209
- # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
206
+ combined_tags = []
207
+ rating_tag_text = ""
208
+ character_tag_text = ""
209
+ general_tag_text = ""
210
210
211
211
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
212
212
# Everything else is tags: pick any where prediction confidence > threshold
213
- combined_tags = []
214
- general_tag_text = ""
215
- character_tag_text = ""
216
213
for i , p in enumerate (prob [4 :]):
217
214
if i < len (general_tags ) and p >= args .general_threshold :
218
215
tag_name = general_tags [i ]
@@ -231,7 +228,24 @@ def run_batch(path_imgs):
231
228
if tag_name not in undesired_tags :
232
229
tag_freq [tag_name ] = tag_freq .get (tag_name , 0 ) + 1
233
230
character_tag_text += caption_separator + tag_name
234
- combined_tags .append (tag_name )
231
+ if args .character_tags_first : # insert to the beginning
232
+ combined_tags .insert (0 ,tag_name )
233
+ else :
234
+ combined_tags .append (tag_name )
235
+
236
+ #最初の4つはratingなので無視する
237
+ # First 4 labels are actually ratings: pick one with argmax
238
+ if args .use_rating_tags :
239
+ ratings_names = prob [:4 ]
240
+ rating_index = ratings_names .argmax ()
241
+ found_rating = rating_tags [rating_index ]
242
+ if args .remove_underscore and len (found_rating ) > 3 :
243
+ found_rating = found_rating .replace ("_" , " " )
244
+
245
+ if found_rating not in undesired_tags :
246
+ tag_freq [found_rating ] = tag_freq .get (found_rating , 0 ) + 1
247
+ rating_tag_text = found_rating
248
+ combined_tags .insert (0 ,found_rating ) # insert to the beginning
235
249
236
250
# 先頭のカンマを取る
237
251
if len (general_tag_text ) > 0 :
@@ -264,6 +278,7 @@ def run_batch(path_imgs):
264
278
if args .debug :
265
279
logger .info ("" )
266
280
logger .info (f"{ image_path } :" )
281
+ logger .info (f"\t Rating tags: { rating_tag_text } " )
267
282
logger .info (f"\t Character tags: { character_tag_text } " )
268
283
logger .info (f"\t General tags: { general_tag_text } " )
269
284
@@ -321,7 +336,9 @@ def run_batch(path_imgs):
321
336
322
337
def setup_parser () -> argparse .ArgumentParser :
323
338
parser = argparse .ArgumentParser ()
324
- parser .add_argument ("train_data_dir" , type = str , help = "directory for train images / 学習画像データのディレクトリ" )
339
+ parser .add_argument (
340
+ "train_data_dir" , type = str , help = "directory for train images / 学習画像データのディレクトリ"
341
+ )
325
342
parser .add_argument (
326
343
"--repo_id" ,
327
344
type = str ,
@@ -339,7 +356,9 @@ def setup_parser() -> argparse.ArgumentParser:
339
356
action = "store_true" ,
340
357
help = "force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" ,
341
358
)
342
- parser .add_argument ("--batch_size" , type = int , default = 1 , help = "batch size in inference / 推論時のバッチサイズ" )
359
+ parser .add_argument (
360
+ "--batch_size" , type = int , default = 1 , help = "batch size in inference / 推論時のバッチサイズ"
361
+ )
343
362
parser .add_argument (
344
363
"--max_data_loader_n_workers" ,
345
364
type = int ,
@@ -378,7 +397,9 @@ def setup_parser() -> argparse.ArgumentParser:
378
397
action = "store_true" ,
379
398
help = "replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える" ,
380
399
)
381
- parser .add_argument ("--debug" , action = "store_true" , help = "debug mode" )
400
+ parser .add_argument (
401
+ "--debug" , action = "store_true" , help = "debug mode"
402
+ )
382
403
parser .add_argument (
383
404
"--undesired_tags" ,
384
405
type = str ,
@@ -388,10 +409,18 @@ def setup_parser() -> argparse.ArgumentParser:
388
409
parser .add_argument (
389
410
"--frequency_tags" , action = "store_true" , help = "Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する"
390
411
)
391
- parser .add_argument ("--onnx" , action = "store_true" , help = "use onnx model for inference / onnxモデルを推論に使用する" )
412
+ parser .add_argument (
413
+ "--onnx" , action = "store_true" , help = "use onnx model for inference / onnxモデルを推論に使用する"
414
+ )
392
415
parser .add_argument (
393
416
"--append_tags" , action = "store_true" , help = "Append captions instead of overwriting / 上書きではなくキャプションを追記する"
394
417
)
418
+ parser .add_argument (
419
+ "--use_rating_tags" , action = "store_true" , help = "Adds rating tags as the first tag" ,
420
+ )
421
+ parser .add_argument (
422
+ "--character_tags_first" , action = "store_true" , help = "Always inserts character tags before the general tags" ,
423
+ )
395
424
parser .add_argument (
396
425
"--caption_separator" ,
397
426
type = str ,
0 commit comments