-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfile.go
More file actions
291 lines (244 loc) · 7.69 KB
/
file.go
File metadata and controls
291 lines (244 loc) · 7.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
// Package ragserver implements the file-handling domain logic for a
// retrieval-augmented-generation (RAG) server: validating and storing
// uploads with content-hash deduplication, tracking file processing
// status, and deleting files together with their retriever documents.
package ragserver
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"io"
"mime/multipart"
"net/http"
"strings"
"time"
"github.com/gofrs/uuid/v5"
"github.com/RichardKnop/ragserver/pkg/authz"
)
// Upload size limits.
const (
	// MB is one mebibyte in bytes.
	MB = 1 << 20
	// MaxFileSize is the maximum accepted upload size (20 MiB).
	MaxFileSize = 20 * MB
)
// FileID uniquely identifies an uploaded file.
type FileID struct{ uuid.UUID }

// NewFileID returns a freshly generated random (v4) file identifier.
func NewFileID() FileID {
	newID := uuid.Must(uuid.NewV4())
	return FileID{newID}
}
// AuthorID uniquely identifies the author (uploader) of a file.
type AuthorID struct{ uuid.UUID }

// NewAuthorID returns a freshly generated random (v4) author identifier.
func NewAuthorID() AuthorID {
	newID := uuid.Must(uuid.NewV4())
	return AuthorID{newID}
}
// FileStatus represents the processing lifecycle state of a File.
type FileStatus string

const (
	// FileStatusUploaded: the file has been stored but processing has not started.
	FileStatusUploaded FileStatus = "UPLOADED"
	// FileStatusProcessing: the file is currently being processed.
	FileStatusProcessing FileStatus = "PROCESSING"
	// FileStatusProcessedSuccessfully: processing completed without error.
	FileStatusProcessedSuccessfully FileStatus = "PROCESSED_SUCCESSFULLY"
	// FileStatusProcessingFailed: processing completed with an error.
	FileStatusProcessingFailed FileStatus = "PROCESSING_FAILED"
)
// File is the metadata record for an uploaded file. The file's bytes
// live in file storage keyed by Hash; this struct is what is persisted
// in the database.
type File struct {
	ID       FileID
	AuthorID AuthorID
	// FileName is the original name supplied by the uploader.
	FileName    string
	ContentType string
	// Extension is derived from ContentType (e.g. "pdf").
	Extension string
	// Size is the file size in bytes.
	Size int64
	// Hash is the hex-encoded SHA-256 of the file contents, used to
	// deduplicate blobs in file storage.
	Hash string
	Embedder  string // adapter used to generate embeddings for this file
	Retriever string // adapter used to store/retrieve embeddings for this file
	// Status tracks the processing lifecycle; see the FileStatus values.
	Status        FileStatus
	StatusMessage string
	Created       time.Time
	Updated       time.Time
	// Documents holds the documents extracted from this file — presumably
	// populated by the store only when requested; confirm with callers.
	Documents []Document
}
// CompleteWithStatus transitions a file out of FileStatusProcessing into
// a completion status, either FileStatusProcessedSuccessfully or
// FileStatusProcessingFailed, recording a status message and the update
// time. It returns an error if the file is not currently processing, or
// if newStatus is not one of the two completion statuses (previously any
// status was accepted, contradicting the documented contract).
func (f *File) CompleteWithStatus(newStatus FileStatus, message string, updatedAt time.Time) error {
	if newStatus != FileStatusProcessedSuccessfully && newStatus != FileStatusProcessingFailed {
		return fmt.Errorf("%s is not a completion status", newStatus)
	}
	if f.Status != FileStatusProcessing {
		return fmt.Errorf("cannot change status from %s to %s", f.Status, newStatus)
	}
	f.Status = newStatus
	f.StatusMessage = message
	f.Updated = updatedAt
	return nil
}
// FileFilter narrows queries against the file store. Zero-valued fields
// are presumably ignored by the store's query builder — confirm against
// the store implementation.
type FileFilter struct {
	Embedder          string
	Retriever         string
	Status            FileStatus
	LastUpdatedBefore time.Time
	ScreeningID       ScreeningID
	// Hash selects files by content hash (used by DeleteFile to detect
	// other records sharing the same stored blob).
	Hash string
	// Lock, when true, presumably acquires row locks on the selected
	// files (SELECT ... FOR UPDATE) — verify in the store implementation.
	Lock bool
}
// TempFile is a temporary scratch file used while spooling uploads: it
// can be written, read back with seeking, closed, and reports its name
// so the file storage can delete it afterwards.
type TempFile interface {
	io.ReadSeekCloser
	io.Writer
	Name() string
}
// CreateFile validates and stores an uploaded file.
//
// The upload is sniffed for an allowed content type, spooled through a
// temp file while a SHA-256 hash is computed, deduplicated by hash in
// file storage, and finally recorded in the database together with the
// uploading principal in a single transaction.
//
// Fixes over the previous version: the temp-file cleanup is deferred
// immediately after creation (it used to be registered only after the
// storage write, leaking the temp file on every earlier error return);
// unsupported content types are rejected before anything is written to
// file storage (previously an image upload left an orphaned blob); the
// declared MaxFileSize limit is now enforced; and errors are wrapped
// with %w throughout.
func (rs *ragServer) CreateFile(ctx context.Context, principal authz.Principal, file io.ReadSeeker, header *multipart.FileHeader) (*File, error) {
	// Enforce the upload size limit up front.
	if header.Size > MaxFileSize {
		return nil, fmt.Errorf("file size %d exceeds maximum allowed size of %d bytes", header.Size, MaxFileSize)
	}

	tempFile, err := rs.filestorage.NewTempFile()
	if err != nil {
		return nil, fmt.Errorf("error creating temp file: %w", err)
	}
	defer tempFile.Close()
	// Register cleanup right away so early error returns do not leak the temp file.
	defer rs.filestorage.DeleteTempFile(tempFile.Name())

	contentType, ok, err := checkContentType(file)
	if err != nil {
		return nil, fmt.Errorf("error checking content type: %w", err)
	}
	if !ok {
		return nil, fmt.Errorf("invalid file type")
	}

	// Decide the extension and reject unsupported types before doing any
	// storage writes, so unsupported uploads leave no orphaned blobs.
	var extension string
	switch contentType {
	case "application/pdf":
		extension = strings.TrimPrefix(contentType, "application/")
	case "image/jpeg", "image/png":
		return nil, fmt.Errorf("image file processing not implemented yet")
	}

	// Content sniffing consumed the start of the stream; rewind before copying.
	if _, err := file.Seek(0, io.SeekStart); err != nil {
		return nil, fmt.Errorf("error seeking file to start: %w", err)
	}

	rs.logger.Sugar().With("filename", header.Filename, "size", header.Size, "header", header.Header).Infof("uploading file")

	// Spool to the temp file and hash the contents in a single pass.
	hashWriter := sha256.New()
	fileSize, err := io.Copy(tempFile, io.TeeReader(file, hashWriter))
	if err != nil {
		return nil, fmt.Errorf("error copying to temp file: %w", err)
	}
	fileHash := hex.EncodeToString(hashWriter.Sum(nil))

	// Content-addressed storage: only write the blob if this hash is new.
	exists, err := rs.filestorage.Exists(fileHash)
	if err != nil {
		return nil, fmt.Errorf("error checking if file exists: %w", err)
	}
	if !exists {
		// Rewind the temp file so the storage write sees the full contents.
		if _, err := tempFile.Seek(0, io.SeekStart); err != nil {
			return nil, fmt.Errorf("error seeking temp file to start: %w", err)
		}
		if err := rs.filestorage.Write(fileHash, tempFile); err != nil {
			return nil, fmt.Errorf("error writing to file storage: %w", err)
		}
	}

	aFile := &File{
		ID:          NewFileID(),
		AuthorID:    AuthorID{principal.ID().UUID},
		FileName:    header.Filename,
		ContentType: contentType,
		Extension:   extension,
		Size:        fileSize,
		Hash:        fileHash,
		Embedder:    rs.embedder.Name(),
		Retriever:   rs.retriever.Name(),
		Status:      FileStatusUploaded,
		Created:     rs.now(),
		Updated:     rs.now(),
	}

	// Persist the principal and the file metadata atomically.
	if err := rs.store.Transactional(ctx, &sql.TxOptions{}, func(ctx context.Context) error {
		if err := rs.store.SavePrincipal(ctx, principal); err != nil {
			return fmt.Errorf("error saving principal: %w", err)
		}
		if err := rs.store.SaveFiles(ctx, aFile); err != nil {
			return fmt.Errorf("error saving file: %w", err)
		}
		return nil
	}); err != nil {
		return nil, fmt.Errorf("error saving file: %w", err)
	}

	return aFile, nil
}
// ListFiles returns all files visible through the server's file partial.
func (rs *ragServer) ListFiles(ctx context.Context, principal authz.Principal) ([]*File, error) {
	var results []*File
	txErr := rs.store.Transactional(ctx, &sql.TxOptions{}, func(ctx context.Context) error {
		listed, listErr := rs.store.ListFiles(ctx, FileFilter{}, rs.filePpartial(), SortParams{})
		if listErr != nil {
			return listErr
		}
		results = listed
		return nil
	})
	if txErr != nil {
		return nil, txErr
	}
	return results, nil
}
// FindFile looks up a single file by its ID.
func (rs *ragServer) FindFile(ctx context.Context, principal authz.Principal, id FileID) (*File, error) {
	var found *File
	txErr := rs.store.Transactional(ctx, &sql.TxOptions{}, func(ctx context.Context) error {
		f, findErr := rs.store.FindFile(ctx, id, rs.filePpartial())
		if findErr != nil {
			return findErr
		}
		found = f
		return nil
	})
	if txErr != nil {
		return nil, txErr
	}
	return found, nil
}
// DeleteFile removes a file record, its retriever documents, and — when
// no other file record shares the same content hash — its blob from file
// storage. Deletion is refused for files that have associated screenings
// or are still in the UPLOADED or PROCESSING status.
//
// Fix over the previous version: the error returned by the
// ListFiles-by-hash query was silently ignored (it was assigned and then
// immediately shadowed), which could have caused the shared blob to be
// deleted or retained based on a garbage result; it is now checked.
func (rs *ragServer) DeleteFile(ctx context.Context, principal authz.Principal, id FileID) error {
	rs.logger.Sugar().With("id", id).Info("deleting file")
	if err := rs.store.Transactional(ctx, &sql.TxOptions{}, func(ctx context.Context) error {
		aFile, err := rs.store.FindFile(ctx, id, rs.filePpartial())
		if err != nil {
			return err
		}
		// A file referenced by screenings must not be deleted.
		screenings, err := rs.store.ListScreenings(ctx, ScreeningFilter{
			FileID: id,
		}, authz.NilPartial, SortParams{})
		if err != nil {
			return fmt.Errorf("error listing screenings: %w", err)
		}
		if len(screenings) > 0 {
			return fmt.Errorf("cannot delete file with associated screenings")
		}
		if aFile.Status == FileStatusUploaded || aFile.Status == FileStatusProcessing {
			return fmt.Errorf("cannot delete file in status %s", aFile.Status)
		}
		// Count records sharing this hash BEFORE deleting the row, so we
		// know whether the stored blob is still referenced afterwards.
		filesWithHash, err := rs.store.ListFiles(ctx, FileFilter{
			Hash: aFile.Hash,
		}, authz.NilPartial, SortParams{})
		if err != nil {
			return fmt.Errorf("error listing files by hash: %w", err)
		}
		if err := rs.store.DeleteFiles(ctx, aFile); err != nil {
			return fmt.Errorf("error deleting file: %w", err)
		}
		// Only delete from file storage if no other files share the same hash.
		if len(filesWithHash) == 1 {
			if err := rs.filestorage.Delete(aFile.Hash); err != nil {
				return fmt.Errorf("error deleting file from storage: %w", err)
			}
		}
		if err := rs.retriever.DeleteFileDocuments(ctx, id); err != nil {
			return fmt.Errorf("error deleting file documents from retriever: %w", err)
		}
		return nil
	}); err != nil {
		return err
	}
	return nil
}
// allowedContentTypes is the set of MIME types accepted for upload.
// The image types are commented out until image processing is
// implemented (CreateFile currently rejects them explicitly).
var allowedContentTypes = map[string]struct{}{
	"application/pdf": {},
	// "image/jpeg": {},
	// "image/png": {},
	// "image/gif": {},
}
// checkContentType sniffs the reader's content type and reports whether
// it is one of the allowed upload types.
func checkContentType(reader io.Reader) (string, bool, error) {
	detected, err := detectContentType(reader)
	if err != nil {
		return "", false, err
	}
	_, allowed := allowedContentTypes[detected]
	return detected, allowed, nil
}
func detectContentType(reader io.Reader) (string, error) {
// At most the first 512 bytes of data are used:
// https://golang.org/src/net/http/sniff.go?s=646:688#L11
buff := make([]byte, 512)
bytesRead, err := reader.Read(buff)
if err != nil && err != io.EOF {
return "", err
}
// Slice to remove fill-up zero values which cause a wrong content type detection in the next step
// (for example a text file which is smaller than 512 bytes)
buff = buff[:bytesRead]
return http.DetectContentType(buff), nil
}