Skip to content

Commit 5b5fd6f

Browse files
committed
Add ability to request all available gpus
Also increased PID field width for py3smi
1 parent 1986169 commit 5b5fd6f

File tree

3 files changed

+39
-30
lines changed

3 files changed

+39
-30
lines changed

py3nvml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
from py3nvml.utils import grab_gpus, get_free_gpus, get_num_procs
66

77
__all__ = ['py3nvml', 'nvidia_smi', 'grab_gpus', 'get_free_gpus', 'get_num_procs']
8-
__version__ = "0.2.3"
8+
__version__ = "0.2.4"

py3nvml/utils.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ def grab_gpus(num_gpus=1, gpu_select=None, gpu_fraction=0.95, max_procs=-1):
1717
variable. Other programs can still come along and snatch your gpu. This
1818
function is more about preventing **you** from stealing someone else's GPU.
1919
20-
If more than 1 GPU is requested but the full amount are available, then it
20+
If more than 1 GPU is requested but not all were available, then it
2121
will set the CUDA_VISIBLE_DEVICES variable to see all the available GPUs.
2222
A warning is generated in this case.
2323
2424
If one or more GPUs were requested and none were available, a Warning
25-
will be raised. Before raising it, the CUDA_VISIBLE_DEVICES will be set to a
26-
blank string. This means the calling function can ignore this warning and
25+
will be raised. Before raising it, the CUDA_VISIBLE_DEVICES will be set to
26+
a blank string. This means the calling function can ignore this warning and
2727
proceed if it chooses to only use the CPU, and it should still be protected
2828
against putting processes on a busy GPU.
2929
@@ -33,10 +33,11 @@ def grab_gpus(num_gpus=1, gpu_select=None, gpu_fraction=0.95, max_procs=-1):
3333
Parameters
3434
----------
3535
num_gpus : int
36-
How many gpus your job needs (optional)
36+
How many gpus your job needs (optional). Can set to -1 to take all
37+
remaining available GPUs.
3738
gpu_select : iterable
3839
A single int or an iterable of ints indicating gpu numbers to
39-
search through. If left blank, will search through all gpus.
40+
search through. If None, will search through all gpus.
4041
gpu_fraction : float
4142
The fraction of a gpu's memory that must be free for the script to see
4243
the gpu as free. Defaults to 0.95. Useful if someone has grabbed a tiny
@@ -55,6 +56,8 @@ def grab_gpus(num_gpus=1, gpu_select=None, gpu_fraction=0.95, max_procs=-1):
5556
RuntimeWarning
5657
If couldn't connect with NVIDIA drivers.
5758
If 1 or more gpus were requested and none were available.
59+
Will NOT raise a RuntimeWarning for mismatch in GPU availability if
60+
`num_gpus` is -1.
5861
ValueError
5962
If the gpu_select option was not understood (can fix by leaving this
6063
field blank, providing an int or an iterable of ints).
@@ -70,15 +73,20 @@ def grab_gpus(num_gpus=1, gpu_select=None, gpu_fraction=0.95, max_procs=-1):
7073
try:
7174
py3nvml.nvmlInit()
7275
except:
73-
str_ = """ Couldn't connect to nvml drivers. Check they are installed correctly.
74-
Proceeding on cpu only..."""
76+
str_ = "Couldn't connect to nvml drivers. Check they are installed " \
77+
"correctly.\nProceeding on cpu only..."
7578
warnings.warn(str_, RuntimeWarning)
7679
logger.warn(str_)
7780
return 0
7881

7982
numDevices = py3nvml.nvmlDeviceGetCount()
8083
gpu_free = [False]*numDevices
8184

85+
warn_about_fewer_gpus = True
86+
if num_gpus == -1:
87+
num_gpus = numDevices
88+
warn_about_fewer_gpus = False
89+
8290
# Flag which gpus we can check
8391
if gpu_select is None:
8492
gpu_check = [True] * numDevices
@@ -91,8 +99,8 @@ def grab_gpus(num_gpus=1, gpu_select=None, gpu_fraction=0.95, max_procs=-1):
9199
for i in gpu_select:
92100
gpu_check[i] = True
93101
except:
94-
raise ValueError('''Please provide an int or an iterable of ints
95-
for gpu_select''')
102+
raise ValueError('Please set gpu_select to None, an int or an'
103+
'iterable of ints.')
96104

97105
# Print out GPU device info. Useful for debugging.
98106
for i in range(numDevices):
@@ -112,7 +120,7 @@ def grab_gpus(num_gpus=1, gpu_select=None, gpu_fraction=0.95, max_procs=-1):
112120
if max_procs >= 0:
113121
procs_ok = get_free_gpus(max_procs=max_procs)
114122
else:
115-
procs_ok = [True,] * numDevices
123+
procs_ok = [True, ] * numDevices
116124

117125
# Now check if any devices are suitable
118126
for i in range(numDevices):
@@ -145,9 +153,9 @@ def grab_gpus(num_gpus=1, gpu_select=None, gpu_fraction=0.95, max_procs=-1):
145153
logger.info('Using {}'.format(use_gpus))
146154
os.environ['CUDA_VISIBLE_DEVICES'] = use_gpus
147155
return num_gpus
148-
else:
156+
elif warn_about_fewer_gpus:
149157
# use everything we can.
150-
s = "Only {} GPUs found but {}".format(sum(gpu_free), num_gpus) + \
158+
s = "Only {} GPUs found but {} ".format(sum(gpu_free), num_gpus) + \
151159
"requested. Allocating these and continuing."
152160
warnings.warn(s, RuntimeWarning)
153161
logger.warn(s)

scripts/py3smi

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from __future__ import print_function
33
from __future__ import division
44
from __future__ import absolute_import
55

6-
from py3nvml.py3nvml import *
6+
from py3nvml.py3nvml import *
77
from datetime import datetime
88
import re
99
import os
@@ -12,20 +12,20 @@ from subprocess import Popen, PIPE
1212
import argparse
1313
from time import sleep
1414
import sys
15+
from contextlib import contextmanager
1516

1617
parser = argparse.ArgumentParser(description='Print GPU stats')
17-
parser.add_argument('-l', '--loop', action='store', type=int,
18-
default=0, help='Loop period')
19-
parser.add_argument('-f', '--full', action='store_true',
20-
help='Print extended version')
21-
parser.add_argument('-w', '--width', type=int, default=77,
22-
help='Print width')
18+
parser.add_argument('-l', '--loop', action='store', type=int, default=0, help='Loop period')
19+
parser.add_argument('-f', '--full', action='store_true', help='Print extended version')
20+
parser.add_argument('-w', '--width', type=int, default=77, help='Print width')
21+
parser.add_argument('--left', action='store_true', help='Prints left part of process name')
2322

2423
COL1_WIDTH = 33
2524
COL2_WIDTH = 21
2625
COL3_WIDTH = 21
2726
WIDTH = 77
2827
LONG_FORMAT = False
28+
LEN_PROCESS_LESS_NAME = 51
2929

3030
gpu_format_col1 = '| {:>3} {:3} {:>5} {:>4} {:>11}|'
3131
gpu_format_col2 = ' {:>19} |'
@@ -76,6 +76,7 @@ def print_proc_header():
7676
print('+' + '=' * args.width + '+')
7777
return 6
7878

79+
7980
def enabled_str(x):
8081
if x == 'Enabled':
8182
return 'On'
@@ -123,7 +124,6 @@ def print_gpu_info(index, long_format=False):
123124
print(gpu_format_col3.format('', ''))
124125
return 1
125126

126-
127127
min_number = try_get_info(nvmlDeviceGetMinorNumber, h)
128128
prod_name = try_get_info(nvmlDeviceGetName, h)
129129
pers_mode = try_get_info(nvmlDeviceGetPersistenceMode, h, 0)
@@ -199,10 +199,13 @@ def print_gpu_info(index, long_format=False):
199199
return n
200200

201201

202-
def cut_proc_name(name, maxlen):
202+
def cut_proc_name(name, maxlen, left=False):
203203
if len(name) > maxlen:
204204
# return '...' + name[-maxlen+3:]
205-
return name[:maxlen-3] + '...'
205+
if left:
206+
return name[:maxlen-2] + '..'
207+
else:
208+
return '..' + name[-maxlen+2:]
206209
else:
207210
return name
208211

@@ -240,7 +243,7 @@ def get_uptime(pid):
240243
return time
241244

242245

243-
def main(full=False):
246+
def main(full=False, left=False):
244247
num_lines = 0
245248
driver_version = nvmlSystemGetDriverVersion()
246249
header_lines = print_header(driver_version, full)
@@ -276,7 +279,7 @@ def main(full=False):
276279
uptime = get_uptime(p.pid)
277280
print(proc_format.format(
278281
min_number, uname, p.pid, uptime,
279-
cut_proc_name(procname, args.width-50),
282+
cut_proc_name(procname, args.width-LEN_PROCESS_LESS_NAME, left),
280283
p.usedGpuMemory >> 20, 'MiB'))
281284
proc_lines += 1
282285
print('+' + '-' * args.width + '+')
@@ -290,22 +293,20 @@ def main(full=False):
290293

291294
if __name__ == '__main__':
292295
args = parser.parse_args()
293-
proc_format = '| {:>3} {:>11} {:>5} {:>11} {: <' + str(args.width-50) + '} {:>5}{:3<} |'
296+
proc_format = '| {:>3} {:>11} {:>7} {:>10} {: <' + str(args.width-LEN_PROCESS_LESS_NAME) + '} {:>5}{:3<} |'
294297
nvmlInit()
295-
print_lines = main(args.full)
298+
print_lines = main(args.full, args.left)
296299

297300
if args.loop > 0:
298301
try:
299302
while True:
300303
sleep(args.loop)
301304
sys.stdout.write("\033[F" * print_lines)
302-
print_lines_new = main(args.full)
305+
print_lines_new = main(args.full, args.left)
303306
if print_lines_new < print_lines:
304307
sys.stdout.write((' '*(args.width+2)+'\n')*(print_lines - print_lines_new))
305308
sys.stdout.write("\033[F" * (print_lines - print_lines_new))
306309
print_lines = print_lines_new
307310
except KeyboardInterrupt:
308311
pass
309312
nvmlShutdown()
310-
311-

0 commit comments

Comments
 (0)