Skip to content

RL Cleanup v2 #965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 48 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select a commit. Hold Shift and click to select a range.
f29ace4
valuefunction code
bglick13 Oct 8, 2022
1684e8b
start example scripts
bglick13 Oct 8, 2022
c757985
missing imports
bglick13 Oct 8, 2022
b315918
bug fixes and placeholder example script
bglick13 Oct 8, 2022
f01c014
add value function scheduler
bglick13 Oct 9, 2022
7b60c93
load value function from hub and get best actions in example
bglick13 Oct 9, 2022
0de435e
very close to working example
bglick13 Oct 10, 2022
a396529
larger batch size for planning
bglick13 Oct 10, 2022
713bd80
more tests
bglick13 Oct 11, 2022
e3fb50f
Merge branch 'main' into rl
bglick13 Oct 11, 2022
686069f
Merge branch 'hf_rl' into rl
bglick13 Oct 11, 2022
d9384ff
merge unet1d changes
bglick13 Oct 11, 2022
52e2668
wandb for debugging, use newer models
bglick13 Oct 11, 2022
75fe8b4
success!
bglick13 Oct 11, 2022
c7fe1dc
turns out we just need more diffusion steps
bglick13 Oct 12, 2022
a6871b1
run on modal
bglick13 Oct 12, 2022
13a443c
Merge branch 'hf_rl' into rl
bglick13 Oct 12, 2022
38616cf
merge and code cleanup
bglick13 Oct 12, 2022
d37b472
use same api for rl model
bglick13 Oct 12, 2022
aa19286
fix variance type
bglick13 Oct 13, 2022
02293e2
wrong normalization function
bglick13 Oct 13, 2022
56818e5
add tests
bglick13 Oct 17, 2022
d085725
style
bglick13 Oct 17, 2022
93fe3ef
style and quality
bglick13 Oct 17, 2022
4e378e9
edits based on comments
bglick13 Oct 18, 2022
e7e6963
style and quality
bglick13 Oct 18, 2022
4f77d89
remove unused var
bglick13 Oct 19, 2022
5de8a6a
Merge branch 'hf_rl' into rl
bglick13 Oct 19, 2022
6bd8397
hack unet1d into a value function
bglick13 Oct 20, 2022
435ad26
add pipeline
bglick13 Oct 20, 2022
5653408
fix arg order
bglick13 Oct 20, 2022
1491932
add pipeline to core library
bglick13 Oct 20, 2022
1a8098e
community pipeline
bglick13 Oct 20, 2022
0e4be75
fix couple shape bugs
bglick13 Oct 21, 2022
5ef88ef
style
bglick13 Oct 21, 2022
c6d94ce
Apply suggestions from code review
Oct 21, 2022
a9cee78
clean up comments
bglick13 Oct 21, 2022
5c8cfc2
Merge remote-tracking branch 'bglick13/rl' into rl
bglick13 Oct 21, 2022
b7fac18
convert older script to using pipeline and add readme
bglick13 Oct 21, 2022
b3edd7b
rename scripts
bglick13 Oct 21, 2022
8b01b93
style, update tests
bglick13 Oct 21, 2022
b0b8b0b
Merge branch 'hf_rl' into rl
bglick13 Oct 22, 2022
3c668a7
delete unet rl model file
bglick13 Oct 22, 2022
af26faa
remove imports in src
Oct 24, 2022
84efdac
add specific vf block and update tests
bglick13 Oct 24, 2022
3bf848f
Merge branch 'hf_rl' into rl
bglick13 Oct 24, 2022
9faf55a
style
bglick13 Oct 24, 2022
24bb52a
Update tests/test_models_unet.py
Oct 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/convert_models_diffuser_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def value_function():
down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
up_block_types=(),
out_block_type="ValueFunction",
mid_block_type="ValueFunctionMidBlock1D",
block_out_channels=(32, 64, 128, 256),
layers_per_block=1,
always_downsample=True,
Expand Down
22 changes: 22 additions & 0 deletions src/diffusers/models/unet_1d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,26 @@ class UpBlock1DNoSkip(nn.Module):
pass


class ValueFunctionMidBlock1D(nn.Module):
    """Mid block used by the value-function variant of ``UNet1DModel``.

    Applies two (residual temporal block -> strided-conv downsample) stages,
    each stage halving the channel count relative to ``in_channels`` while the
    downsample convolutions shrink the sequence length.

    Args:
        in_channels: number of channels entering the block.
        out_channels: channel count used to size the downsample convolutions.
            NOTE(review): the residual blocks are sized from ``in_channels``
            while the downsamplers use ``out_channels`` — these only line up
            when the two are equal; confirm with the calling config.
        embed_dim: dimensionality of the timestep embedding fed to the
            residual blocks.
    """

    def __init__(self, in_channels, out_channels, embed_dim):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim

        # Keep attribute names and creation order stable: they determine the
        # state_dict keys and parameter iteration order.
        self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
        self.down1 = Downsample1D(out_channels // 2, use_conv=True)
        self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
        self.down2 = Downsample1D(out_channels // 4, use_conv=True)

    def forward(self, x, temb=None):
        """Run both residual+downsample stages in sequence.

        Args:
            x: input hidden states (channels-first 1D feature map).
            temb: optional timestep embedding passed to each residual block.

        Returns:
            The downsampled hidden states.
        """
        for resnet, downsample in ((self.res1, self.down1), (self.res2, self.down2)):
            x = downsample(resnet(x, temb))
        return x


class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -307,6 +327,8 @@ def get_mid_block(mid_block_type, num_layers, in_channels, out_channels, embed_d
embed_dim=embed_dim,
add_downsample=add_downsample,
)
elif mid_block_type == "ValueFunctionMidBlock1D":
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
raise ValueError(f"{mid_block_type} does not exist.")


Expand Down
16 changes: 9 additions & 7 deletions tests/test_models_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel
from diffusers.utils import floats_tensor, slow, torch_device
from regex import subf

from .test_modeling_common import ModelTesterMixin

Expand Down Expand Up @@ -489,7 +490,7 @@ def prepare_init_args_and_inputs_for_common(self):

def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained(
"fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
Expand All @@ -500,7 +501,7 @@ def test_from_pretrained_hub(self):
assert image is not None, "Make sure output is not None"

def test_output_pretrained(self):
model = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128")
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
Expand All @@ -517,7 +518,8 @@ def test_output_pretrained(self):

output_slice = output[0, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
expected_output_slice = torch.tensor([-2.137172 , 1.1426016 , 0.3688687 , -0.766922 , 0.7303146 ,
0.11038864, -0.4760633 , 0.13270172, 0.02591348])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))

Expand Down Expand Up @@ -565,10 +567,10 @@ def prepare_init_args_and_inputs_for_common(self):

def test_from_pretrained_hub(self):
unet, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
)
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
self.assertIsNotNone(unet)
self.assertEqual(len(loading_info["missing_keys"]), 0)
Expand All @@ -583,7 +585,7 @@ def test_from_pretrained_hub(self):

def test_output_pretrained(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
torch.manual_seed(0)
if torch.cuda.is_available():
Expand All @@ -600,7 +602,7 @@ def test_output_pretrained(self):
output = value_function(noise, time_step).sample

# fmt: off
expected_output_slice = torch.tensor([207.0272] * seq_len)
expected_output_slice = torch.tensor([165.25] * seq_len)
# fmt: on
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))

Expand Down