Skip to content

Commit 6ba8428

Browse files
authored
Merge pull request #1216 from Disty0/dev
Rating support for WD Tagger
2 parents 434dc40 + bc586ce commit 6ba8428

File tree

3 files changed

+68
-25
lines changed

3 files changed

+68
-25
lines changed

finetune/tag_images_by_wd14_tagger.py

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ def main(args):
130130
input_name = model.graph.input[0].name
131131
try:
132132
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
133-
except:
133+
except Exception:
134134
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
135135

136-
if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0:
136+
if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
137137
# some rebatch model may use 'N' as dynamic axes
138138
logger.warning(
139139
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
@@ -169,13 +169,14 @@ def main(args):
169169

170170
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
171171
reader = csv.reader(f)
172-
l = [row for row in reader]
173-
header = l[0] # tag_id,name,category,count
174-
rows = l[1:]
172+
line = [row for row in reader]
173+
header = line[0] # tag_id,name,category,count
174+
rows = line[1:]
175175
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
176176

177-
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
178-
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
177+
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
178+
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
179+
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
179180

180181
# 画像を読み込む
181182

@@ -202,17 +203,13 @@ def run_batch(path_imgs):
202203
probs = probs.numpy()
203204

204205
for (image_path, _), prob in zip(path_imgs, probs):
205-
# 最初の4つはratingなので無視する
206-
# # First 4 labels are actually ratings: pick one with argmax
207-
# ratings_names = label_names[:4]
208-
# rating_index = ratings_names["probs"].argmax()
209-
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
206+
combined_tags = []
207+
rating_tag_text = ""
208+
character_tag_text = ""
209+
general_tag_text = ""
210210

211211
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
212212
# Everything else is tags: pick any where prediction confidence > threshold
213-
combined_tags = []
214-
general_tag_text = ""
215-
character_tag_text = ""
216213
for i, p in enumerate(prob[4:]):
217214
if i < len(general_tags) and p >= args.general_threshold:
218215
tag_name = general_tags[i]
@@ -231,7 +228,24 @@ def run_batch(path_imgs):
231228
if tag_name not in undesired_tags:
232229
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
233230
character_tag_text += caption_separator + tag_name
234-
combined_tags.append(tag_name)
231+
if args.character_tags_first: # insert to the beginning
232+
combined_tags.insert(0,tag_name)
233+
else:
234+
combined_tags.append(tag_name)
235+
236+
#最初の4つはratingなので無視する
237+
# First 4 labels are actually ratings: pick one with argmax
238+
if args.use_rating_tags:
239+
ratings_names = prob[:4]
240+
rating_index = ratings_names.argmax()
241+
found_rating = rating_tags[rating_index]
242+
if args.remove_underscore and len(found_rating) > 3:
243+
found_rating = found_rating.replace("_", " ")
244+
245+
if found_rating not in undesired_tags:
246+
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
247+
rating_tag_text = found_rating
248+
combined_tags.insert(0,found_rating) # insert to the beginning
235249

236250
# 先頭のカンマを取る
237251
if len(general_tag_text) > 0:
@@ -264,6 +278,7 @@ def run_batch(path_imgs):
264278
if args.debug:
265279
logger.info("")
266280
logger.info(f"{image_path}:")
281+
logger.info(f"\tRating tags: {rating_tag_text}")
267282
logger.info(f"\tCharacter tags: {character_tag_text}")
268283
logger.info(f"\tGeneral tags: {general_tag_text}")
269284

@@ -321,7 +336,9 @@ def run_batch(path_imgs):
321336

322337
def setup_parser() -> argparse.ArgumentParser:
323338
parser = argparse.ArgumentParser()
324-
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
339+
parser.add_argument(
340+
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
341+
)
325342
parser.add_argument(
326343
"--repo_id",
327344
type=str,
@@ -339,7 +356,9 @@ def setup_parser() -> argparse.ArgumentParser:
339356
action="store_true",
340357
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
341358
)
342-
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
359+
parser.add_argument(
360+
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
361+
)
343362
parser.add_argument(
344363
"--max_data_loader_n_workers",
345364
type=int,
@@ -378,7 +397,9 @@ def setup_parser() -> argparse.ArgumentParser:
378397
action="store_true",
379398
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
380399
)
381-
parser.add_argument("--debug", action="store_true", help="debug mode")
400+
parser.add_argument(
401+
"--debug", action="store_true", help="debug mode"
402+
)
382403
parser.add_argument(
383404
"--undesired_tags",
384405
type=str,
@@ -388,10 +409,18 @@ def setup_parser() -> argparse.ArgumentParser:
388409
parser.add_argument(
389410
"--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する"
390411
)
391-
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
412+
parser.add_argument(
413+
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
414+
)
392415
parser.add_argument(
393416
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
394417
)
418+
parser.add_argument(
419+
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag",
420+
)
421+
parser.add_argument(
422+
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags",
423+
)
395424
parser.add_argument(
396425
"--caption_separator",
397426
type=str,

library/ipex/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def ipex_init(): # pylint: disable=too-many-statements
3232
torch.cuda.FloatTensor = torch.xpu.FloatTensor
3333
torch.Tensor.cuda = torch.Tensor.xpu
3434
torch.Tensor.is_cuda = torch.Tensor.is_xpu
35+
torch.nn.Module.cuda = torch.nn.Module.xpu
3536
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
3637
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
3738
torch.cuda._initialized = torch.xpu.lazy_init._initialized
@@ -147,9 +148,9 @@ def ipex_init(): # pylint: disable=too-many-statements
147148

148149
# C
149150
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
150-
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
151-
ipex._C._DeviceProperties.major = 2023
152-
ipex._C._DeviceProperties.minor = 2
151+
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
152+
ipex._C._DeviceProperties.major = 2024
153+
ipex._C._DeviceProperties.minor = 0
153154

154155
# Fix functions with ipex:
155156
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]

library/ipex/hijacks.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,16 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
190190
else:
191191
return original_Tensor_cuda(self, device, *args, **kwargs)
192192

193+
original_Tensor_pin_memory = torch.Tensor.pin_memory
194+
@wraps(torch.Tensor.pin_memory)
195+
def Tensor_pin_memory(self, device=None, *args, **kwargs):
196+
if device is None:
197+
device = "xpu"
198+
if check_device(device):
199+
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
200+
else:
201+
return original_Tensor_pin_memory(self, device, *args, **kwargs)
202+
193203
original_UntypedStorage_init = torch.UntypedStorage.__init__
194204
@wraps(torch.UntypedStorage.__init__)
195205
def UntypedStorage_init(*args, device=None, **kwargs):
@@ -259,17 +269,20 @@ def torch_Generator(device=None):
259269
original_torch_load = torch.load
260270
@wraps(torch.load)
261271
def torch_load(f, map_location=None, *args, **kwargs):
272+
if map_location is None:
273+
map_location = "xpu"
262274
if check_device(map_location):
263-
return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
275+
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
264276
else:
265-
return original_torch_load(f, map_location=map_location, *args, **kwargs)
277+
return original_torch_load(f, *args, map_location=map_location, **kwargs)
266278

267279

268280
# Hijack Functions:
269281
def ipex_hijacks():
270282
torch.tensor = torch_tensor
271283
torch.Tensor.to = Tensor_to
272284
torch.Tensor.cuda = Tensor_cuda
285+
torch.Tensor.pin_memory = Tensor_pin_memory
273286
torch.UntypedStorage.__init__ = UntypedStorage_init
274287
torch.UntypedStorage.cuda = UntypedStorage_cuda
275288
torch.empty = torch_empty

0 commit comments

Comments (0)