Skip to content

Commit 95ad90c

Browse files
authored
Merge pull request #91 from jfontan/cancel-copy
downloader: make recursiveCopy context aware
2 parents f683a2a + 62830f6 commit 95ad90c

File tree

2 files changed

+57
-19
lines changed

2 files changed

+57
-19
lines changed

downloader/download_test.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ func TestAll(t *testing.T) {
128128
{"testAuthSuccess", testAuthSuccess},
129129
{"testAuthErrors", testAuthErrors},
130130
{"testContextCancelledFail", testContextCancelledFail},
131+
{"testContextCancelledPrepareRepo", testContextCancelledPrepareRepo},
131132
{"testWrongEndpointFail", testWrongEndpointFail},
132133
{"testAlreadyDownloadedFail", testAlreadyDownloadedFail},
133134
{"testDownloadConcurrentSuccess", testDownloadConcurrentSuccess},
@@ -250,7 +251,21 @@ func testContextCancelledFail(t *testing.T, h *testhelper.Helper) {
250251
}
251252
job.SetEndpoints([]string{endPoint(gitProtocol, testRepo)})
252253

253-
require.Equal(t, fmt.Errorf("context canceled"), Download(ctx, job))
254+
require.Equal(t, context.Canceled, Download(ctx, job))
255+
}
256+
257+
// testContextCancelledPrepareRepo
258+
// 1) tries to prepare a repository with a cancelled context. Previously this
259+
// caused a race condition now it should be correct.
260+
func testContextCancelledPrepareRepo(t *testing.T, h *testhelper.Helper) {
261+
ctx, cancel := context.WithCancel(context.Background())
262+
cancel()
263+
264+
testRepo := tests[0].repoIDs[0]
265+
repo, err := PrepareRepository(ctx, h.Lib, "location", testRepo,
266+
endPoint(gitProtocol, testRepo), h.TempFS, "tmp")
267+
require.Error(t, err)
268+
require.Nil(t, repo)
254269
}
255270

256271
// testWrongEndpointFail

downloader/git.go

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -264,33 +264,27 @@ func createRootedRepo(
264264
return nil, err
265265
}
266266

267-
done := make(chan struct{})
268-
go func() {
269-
err = recursiveCopy(
270-
"/", repo.FS(),
271-
clonedPath, clonedFS,
272-
)
273-
274-
close(done)
275-
}()
276-
277-
select {
278-
case <-done:
279-
case <-ctx.Done():
280-
err = ctx.Err()
281-
repo.Close()
267+
err = recursiveCopy(ctx, "/", repo.FS(), clonedPath, clonedFS)
268+
if err != nil {
282269
repo = nil
283270
}
284271

285272
return repo, err
286273
}
287274

288275
func recursiveCopy(
276+
ctx context.Context,
289277
dst string,
290278
dstFS billy.Filesystem,
291279
src string,
292280
srcFS billy.Filesystem,
293281
) error {
282+
select {
283+
case <-ctx.Done():
284+
return ctx.Err()
285+
default:
286+
}
287+
294288
stat, err := srcFS.Stat(src)
295289
if err != nil {
296290
return err
@@ -311,13 +305,13 @@ func recursiveCopy(
311305
srcPath := filepath.Join(src, file.Name())
312306
dstPath := filepath.Join(dst, file.Name())
313307

314-
err = recursiveCopy(dstPath, dstFS, srcPath, srcFS)
308+
err = recursiveCopy(ctx, dstPath, dstFS, srcPath, srcFS)
315309
if err != nil {
316310
return err
317311
}
318312
}
319313
} else {
320-
err = copyFile(dst, dstFS, src, srcFS, stat.Mode())
314+
err = copyFile(ctx, dst, dstFS, src, srcFS, stat.Mode())
321315
if err != nil {
322316
return err
323317
}
@@ -327,12 +321,19 @@ func recursiveCopy(
327321
}
328322

329323
func copyFile(
324+
ctx context.Context,
330325
dst string,
331326
dstFS billy.Filesystem,
332327
src string,
333328
srcFS billy.Filesystem,
334329
mode os.FileMode,
335330
) error {
331+
select {
332+
case <-ctx.Done():
333+
return ctx.Err()
334+
default:
335+
}
336+
336337
_, err := srcFS.Stat(src)
337338
if err != nil {
338339
return err
@@ -350,7 +351,7 @@ func copyFile(
350351
}
351352
defer fd.Close()
352353

353-
_, err = io.Copy(fd, fo)
354+
_, err = io.Copy(fd, newContextReader(ctx, fo))
354355
if err != nil {
355356
fd.Close()
356357
dstFS.Remove(dst)
@@ -359,3 +360,25 @@ func copyFile(
359360

360361
return nil
361362
}
363+
364+
type contextReader struct {
365+
reader io.Reader
366+
ctx context.Context
367+
}
368+
369+
func newContextReader(ctx context.Context, reader io.Reader) *contextReader {
370+
return &contextReader{
371+
ctx: ctx,
372+
reader: reader,
373+
}
374+
}
375+
376+
func (c *contextReader) Read(p []byte) (n int, err error) {
377+
select {
378+
case <-c.ctx.Done():
379+
return 0, c.ctx.Err()
380+
default:
381+
}
382+
383+
return c.reader.Read(p)
384+
}

0 commit comments

Comments
 (0)