Add rocm perf yml file #418
base: rocm-jaxlib-v0.5.0
Conversation
.github/workflows/rocm-perf.yml
```python
times.append(float(m.group(1)))
if times:
    summary[model] = {
        "median_step_time": round(float(np.median(times)), 3),
        "steps_counted": len(times)
    }
```
grab the parsed steps too with the summary
Suggested change:

```python
times.append(float(m.group(1)))
if times:
    step_info = [{"step": n, "time": t} for n, t in enumerate(times)]
    summary[model] = {
        "steps": step_info,
        "median_step_time": round(float(np.median(times)), 3),
        "steps_counted": len(times)
    }
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good
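For reference, with the suggested change in place, one entry of the summary would look roughly like this (the model name and timings below are made-up placeholders, not values from an actual CI run):

```python
# Hypothetical shape of one summary.json entry after the suggested change.
summary = {
    "example_model": {
        "steps": [
            {"step": 0, "time": 0.512},
            {"step": 1, "time": 0.498},
        ],
        "median_step_time": 0.505,
        "steps_counted": 2,
    }
}
```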
Force-pushed from d2912f5 to 952cc4f.
This is a port from the other performance CI PR, right? Could you add a description and link to the original PR?
Are we considering grok and alphafold models? @Ruturaj4 @JehandadKhan

Yes, we are definitely planning to add alphafold; however, grok testing takes too much time to download the weights. If grok training can be done, or if there are ways to run grok faster, we are happy to add those as well!
Why did you choose to report the median step time? I don't know the rationale for that, but in general I'm not sure the median is the right metric here. It rejects outliers, and on its own it doesn't describe the distribution of values at all, yet that is exactly what is important to know:
TLDR: the mean seems like a much better metric here. For the best results, I'd report 6 values: the [0, 25, 50, 75, 100]% quantiles plus the mean (because of the last bullet point), and, God forbid, stddev.
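For illustration, a minimal sketch of this kind of reporting, assuming the per-step times for one model are already collected in a list (the function name and sample values are hypothetical):

```python
import numpy as np

def step_time_stats(times):
    """Return the [0, 25, 50, 75, 100]% quantiles plus the mean of step times."""
    q = np.percentile(times, [0, 25, 50, 75, 100])
    return {
        "min": round(float(q[0]), 3),
        "p25": round(float(q[1]), 3),
        "median": round(float(q[2]), 3),
        "p75": round(float(q[3]), 3),
        "max": round(float(q[4]), 3),
        "mean": round(float(np.mean(times)), 3),
    }

# Example: step times for one model, warmup step already excluded.
print(step_time_stats([0.52, 0.49, 0.51, 0.50, 0.98]))
```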
: ) why did you not reply instead of editing it... I didn't even notice you had replied.

@Ruturaj4 I remember we put the grok weights in a shared directory last time to avoid duplicate downloads. Not sure whether we can achieve this on CI, something like a persistent storage node for the grok weights?
Force-pushed from 952cc4f to e396bec.
Ohh, I accidentally edited your reply! Yup, we can do something like that as long as we have space somewhere that the CI nodes can access.
I decided to use the median as per the suggestion from Jehandad. But I will also add the mean time as per your suggestion.
Force-pushed from e396bec to d6c595d.
Looks good to me.
These 4 models are all LLMs; can we add one more for video? https://github.com/google-research/scenic/tree/main/scenic/projects/baselines

In addition, can we add one simulation model as well? For example https://github.com/Autodesk/XLB

Furthermore, I have found that JAX is very popular in the AI-for-science community, and there are a number of frameworks/models based on JAX. I wonder whether we could try one of them for the JAX perf CI (maybe alphafold is enough?), for example https://github.com/sail-sg/jax_xc

In general, I hope we can cover LLM, video, simulation, and AI4Science.
This PR adds a new GitHub Actions workflow that:

- Builds JAX with ROCm support inside a Docker container.
- Runs training for the following MaxText models:
- Captures `stdout` logs for each model and extracts per-step timing
- Ignores step 0 (warmup) when computing metrics
- Computes `median_step_time` per model and saves it to `summary.json`
- Uploads logs and metrics as workflow artifacts

A Python analysis script (`analyze_maxtext_logs.py`) is added under `jax/build/rocm/` to parse logs and generate the summary.
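For context, a minimal sketch of this kind of log parsing, assuming a per-step timing line can be matched with a regex. The regex pattern, log line format, and function names here are assumptions for illustration, not the actual contents of `analyze_maxtext_logs.py`:

```python
import json
import re
import numpy as np

# Assumed per-step log line format, e.g. "completed step: 3, step_time: 0.512";
# the real MaxText log format and the script's regex may differ.
STEP_RE = re.compile(r"step_time:\s*([0-9.]+)")

def summarize_log(model, log_lines, summary):
    """Extract per-step times from one model's stdout log and record stats."""
    times = []
    for line in log_lines:
        m = STEP_RE.search(line)
        if m:
            times.append(float(m.group(1)))
    times = times[1:]  # drop step 0: it includes compile/warmup time
    if times:
        summary[model] = {
            "median_step_time": round(float(np.median(times)), 3),
            "steps_counted": len(times),
        }

if __name__ == "__main__":
    # Hypothetical captured stdout; the workflow would read real log files instead.
    sample_log = [
        "completed step: 0, step_time: 2.134",
        "completed step: 1, step_time: 0.512",
        "completed step: 2, step_time: 0.498",
    ]
    summary = {}
    summarize_log("example_model", sample_log, summary)
    print(json.dumps(summary, indent=2))  # the real script writes this to summary.json
```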