forked from bakulafalls/BCI_StudyNotes
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCNN_Module.py
More file actions
102 lines (75 loc) · 2.78 KB
/
Copy pathCNN_Module.py
File metadata and controls
102 lines (75 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch import nn
# ## Helper functions
def show_plot(iteration, accuracy, loss):
    """Plot accuracy and loss curves against ``iteration`` and show the figure.

    Args:
        iteration: x-axis values shared by both curves.
        accuracy: accuracy values, one per entry in ``iteration``.
        loss: loss values, one per entry in ``iteration``.
    """
    # The original called plt.plot(iteration, accuracy, loss). Matplotlib groups
    # positional arguments as (x, y[, fmt]), so a third array became a new
    # y-series drawn against its indices 0..N-1 (and a string would be taken as
    # a format spec). Plot each curve explicitly against the same x-axis instead.
    plt.plot(iteration, accuracy, label="accuracy")
    plt.plot(iteration, loss, label="loss")
    plt.legend()
    plt.show()
def test_show_plot(iteration, accuracy):
    """Draw a single ``accuracy`` curve against ``iteration`` and display it.

    Args:
        iteration: x-axis values.
        accuracy: y-axis values, one per entry in ``iteration``.
    """
    axes = plt.gca()
    axes.plot(iteration, accuracy)
    plt.show()
# ## Configuration helper class
class Config:
    """Plain holder for data locations, batch sizes and epoch counts.

    All settings are class attributes, read as e.g. ``Config.train_batch_size``.
    """

    # Dataset directories (relative paths).
    training_dir = "./data/faces/training/"
    testing_dir = "./data/faces/testing/"

    # Mini-batch sizes for the training and test loaders.
    train_batch_size = test_batch_size = 48

    # Number of epochs for training and for testing.
    train_number_epochs = 100
    test_number_epochs = 20
class CNNNetDataset(Dataset):
    """Map-style dataset backed by two files written with ``torch.save``.

    Both files are loaded eagerly in ``__init__`` and converted to ``float32``
    NumPy arrays. ``__getitem__`` returns ``(item, target)`` for one index,
    passing each through its optional transform.

    Args:
        file_path: path of the saved input data; samples are taken along its
            first axis (``data[index, :]``).
        target_path: path of the saved targets; indexed along its first axis.
        transform: optional callable applied to each item.
        target_transform: optional callable applied to each target.

    NOTE(review): the lengths of data and target are not checked against each
    other; ``__len__`` reports the length of the data only.
    """

    def __init__(self, file_path, target_path, transform=None, target_transform=None):
        self.file_path = file_path
        self.target_path = target_path
        self.data = self.parse_data_file(file_path)
        self.target = self.parse_target_file(target_path)
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _load_float32(path):
        """Load an object saved with ``torch.save`` and return it as a float32 ndarray."""
        # SECURITY NOTE(review): torch.load unpickles the file unless the installed
        # torch defaults to weights_only=True; only load files from trusted sources.
        return np.array(torch.load(path), dtype=np.float32)

    def parse_data_file(self, file_path):
        """Return the input data stored at ``file_path`` as a float32 array."""
        return self._load_float32(file_path)

    def parse_target_file(self, target_path):
        """Return the targets stored at ``target_path`` as a float32 array."""
        # NOTE(review): targets become float32; if they are class indices for
        # nn.CrossEntropyLoss / F.nll_loss, the caller must cast them to int64.
        return self._load_float32(target_path)

    def __len__(self):
        """Return the number of samples (length of the data's first axis)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(item, target)`` for ``index``, with the optional transforms applied."""
        item = self.data[index, :]
        target = self.target[index]
        # Compare against None instead of relying on truthiness: a callable
        # transform object that defines __len__/__bool__ as 0/False would
        # otherwise be silently skipped.
        if self.transform is not None:
            item = self.transform(item)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return item, target
class CNNNet(nn.Module):
    """Small 2-D CNN classifier that outputs probabilities over 4 classes.

    Pipeline:
    - two (1, 3) stride-2 convolutions with ELU (22 -> 44 -> 88 channels);
    - batch norm and 2x2 max-pooling;
    - a (1, 3) stride-2 convolution with ReLU (88 -> 44 channels);
    - flatten, then fully connected layers 88 -> 64 -> 32 -> 4;
    - dropout after fc1 and a softmax over the class dimension.

    NOTE(review): ``fc1`` expects 88 flattened features while ``conv3`` emits
    44 channels, so the input's spatial size must leave exactly 2 positions
    after ``conv3`` (for example an input of shape (batch, 22, 5, 43)).
    Confirm against the data pipeline.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(22, 44, (1, 3), stride=2)
        self.conv2 = nn.Conv2d(44, 88, (1, 3), stride=2)
        # Bug fix: the original ``nn.BatchNorm2d(88, False)`` passed False as the
        # second positional parameter, which is ``eps`` (not ``affine``), so eps
        # was 0 and constant channels could divide by zero. Use the default eps.
        # affine stays True as before, so parameter names and shapes (and any
        # saved checkpoints) are unchanged.
        self.batchnorm1 = nn.BatchNorm2d(88)
        self.pooling1 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(88, 44, (1, 3), stride=2)
        self.fc1 = nn.Linear(88, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 4)

    def forward(self, item):
        """Return class probabilities of shape (batch, 4) for ``item``.

        Args:
            item: input tensor with 22 channels, presumably (batch, 22, H, W);
                see the class note on the required H, W.
        """
        x = F.elu(self.conv1(item))
        x = F.elu(self.conv2(x))
        x = self.batchnorm1(x)
        x = self.pooling1(x)
        x = F.relu(self.conv3(x))
        # Flatten every dimension after the batch dimension. In the original
        # view(x.size(0), -1), the -1 was the inferred feature count, not the
        # batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Bug fix: F.dropout defaults to training=True, so the original kept
        # dropping activations after model.eval(). Tie it to the module's mode.
        x = F.dropout(x, 0.25, training=self.training)
        x = F.relu(self.fc2(x))
        # NOTE(review): this returns probabilities. If training uses
        # nn.CrossEntropyLoss (which applies log-softmax itself), softmax is
        # effectively applied twice; confirm against the training loop.
        return F.softmax(self.fc3(x), dim=1)