PyTorch 1.7.1
Last updated: 2023-01-18
PyTorch
Training Code
Sample MNIST image-classification code based on the PyTorch framework. Click here to download the dataset.
For single-node training (number of compute nodes equals 1), the sample code is as follows:
Python
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms
import codecs
import errno
import gzip
import numpy as np
import os
from PIL import Image

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--train-dir', type=str, default='./train_data',
                    help='input data dir for training (default: ./train_data)')
parser.add_argument('--test-dir', type=str, default='./test_data',
                    help='input data dir for test (default: ./test_data)')
parser.add_argument('--output-dir', type=str, default='./output',
                    help='output dir for custom job (default: ./output)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                    help='input batch size for testing (default: 64)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')

# Dataset definition for MNIST
class MNIST(data.Dataset):
    """
    MNIST dataset
    """
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, train=True, transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.preprocess(root, train, False)
        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        """
        raw folder
        """
        return os.path.join('/tmp', 'raw')

    @property
    def processed_folder(self):
        """
        processed folder
        """
        return os.path.join('/tmp', 'processed')

    # data preprocessing
    def preprocess(self, train_dir, train, remove_finished=False):
        """
        preprocess
        """
        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)
        train_list = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz']
        test_list = ['t10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
        zip_list = train_list if train else test_list
        for zip_file in zip_list:
            print('Extracting {}'.format(zip_file))
            zip_file_path = os.path.join(train_dir, zip_file)
            raw_folder_path = os.path.join(self.raw_folder, zip_file)
            with open(raw_folder_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(zip_file_path) as zip_f:
                out_f.write(zip_f.read())
            if remove_finished:
                os.unlink(zip_file_path)
        if train:
            training_set = (
                read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
                read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
            )
            with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
                torch.save(training_set, f)
        else:
            test_set = (
                read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
                read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
            )
            with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
                torch.save(test_set, f)

def get_int(b):
    """
    get int
    """
    return int(codecs.encode(b, 'hex'), 16)

def read_label_file(path):
    """
    read label file
    """
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2049
        length = get_int(data[4:8])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
        return torch.from_numpy(parsed).view(length).long()

def read_image_file(path):
    """
    read image file
    """
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2051
        length = get_int(data[4:8])
        num_rows = get_int(data[8:12])
        num_cols = get_int(data[12:16])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
        return torch.from_numpy(parsed).view(length, num_rows, num_cols)

def makedir_exist_ok(dirpath):
    """
    Python 2 fallback for os.makedirs(..., exist_ok=True)
    """
    try:
        os.makedirs(dirpath)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise

# Network definition
class Net(nn.Module):
    """
    Net
    """
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """
        forward
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def train(epoch):
    """
    train
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)  # forward pass
        loss = F.nll_loss(output, target)  # compute the loss
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, len(train_loader),
                100. * batch_idx / len(train_loader), loss.item()))

def test():
    """
    test
    """
    model.eval()
    test_loss = 0.
    test_accuracy = 0.
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        output = model(data)
        # sum up batch loss
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=True)[1]
        test_accuracy += pred.eq(target.data.view_as(pred)).cpu().float().sum()
    test_loss /= len(test_loader) * args.test_batch_size
    test_accuracy /= len(test_loader) * args.test_batch_size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
        test_loss, 100. * test_accuracy))

def save():
    """
    save
    """
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # Save the model.
    torch.save(model.state_dict(), os.path.join(args.output_dir, 'model.pkl'))

if __name__ == '__main__':
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # If there is no test set, use the training set for validation.
    if not os.path.exists(args.test_dir) or not os.listdir(args.test_dir):
        args.test_dir = args.train_dir
    # Convert the data from PIL.Image / numpy.ndarray to torch.FloatTensor.
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_set = MNIST(root=args.train_dir, train=True, transform=trans)
    test_set = MNIST(root=args.test_dir, train=False, transform=trans)
    # Define the data loaders.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.test_batch_size,
        shuffle=False)
    # Build the model.
    model = Net()
    if args.cuda:
        # Move model to GPU.
        model.cuda()
    print(model)
    # Define the optimizer.
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum)
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test()
    save()
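The script above writes the trained weights to `<output-dir>/model.pkl`. Assuming it is saved as `train_mnist.py` (a hypothetical filename) and the four MNIST `.gz` archives are placed under `./train_data` and `./test_data`, a local single-node run could be started with, for example, `python train_mnist.py --train-dir ./train_data --test-dir ./test_data --output-dir ./output --epochs 10`.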
For distributed training (more than one compute node), the sample code is as follows:
Note: the distributed demo does not shard the raw data files and is provided for reference only.
Python
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms
import codecs
import errno
import gzip
import numpy as np
import os
from PIL import Image
import torch.multiprocessing as mp
import torch.utils.data.distributed
import horovod.torch as hvd

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--train-dir', type=str, default='./train_data',
                    help='input data dir for training (default: ./train_data)')
parser.add_argument('--test-dir', type=str, default='./test_data',
                    help='input data dir for test (default: ./test_data)')
parser.add_argument('--output-dir', type=str, default='./output',
                    help='output dir for custom job (default: ./output)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                    help='input batch size for testing (default: 64)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')
parser.add_argument('--use-adasum', action='store_true', default=False,
                    help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
                    help='apply gradient predivide factor in optimizer (default: 1.0)')

# Dataset definition for MNIST
class MNIST(data.Dataset):
    """
    MNIST dataset
    """
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, train=True, transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.preprocess(root, train, False)
        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        """
        raw folder
        """
        return os.path.join('/tmp', 'raw')

    @property
    def processed_folder(self):
        """
        processed folder
        """
        return os.path.join('/tmp', 'processed')

    # data preprocessing
    def preprocess(self, train_dir, train, remove_finished=False):
        """
        preprocess
        """
        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)
        train_list = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz']
        test_list = ['t10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
        zip_list = train_list if train else test_list
        for zip_file in zip_list:
            print('Extracting {}'.format(zip_file))
            zip_file_path = os.path.join(train_dir, zip_file)
            raw_folder_path = os.path.join(self.raw_folder, zip_file)
            with open(raw_folder_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(zip_file_path) as zip_f:
                out_f.write(zip_f.read())
            if remove_finished:
                os.unlink(zip_file_path)
        if train:
            training_set = (
                read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
                read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
            )
            with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
                torch.save(training_set, f)
        else:
            test_set = (
                read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
                read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
            )
            with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
                torch.save(test_set, f)

def get_int(b):
    """
    get int
    """
    return int(codecs.encode(b, 'hex'), 16)

def read_label_file(path):
    """
    read label file
    """
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2049
        length = get_int(data[4:8])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
        return torch.from_numpy(parsed).view(length).long()

def read_image_file(path):
    """
    read image file
    """
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2051
        length = get_int(data[4:8])
        num_rows = get_int(data[8:12])
        num_cols = get_int(data[12:16])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
        return torch.from_numpy(parsed).view(length, num_rows, num_cols)

def makedir_exist_ok(dirpath):
    """
    Python 2 fallback for os.makedirs(..., exist_ok=True)
    """
    try:
        os.makedirs(dirpath)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise

# Network definition
class Net(nn.Module):
    """
    Net
    """
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """
        forward
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def train(epoch):
    """
    train
    """
    model.train()
    # Horovod: set epoch to sampler for shuffling.
    train_sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            # Horovod: use train_sampler to determine the number of examples in
            # this worker's partition.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_sampler),
                100. * batch_idx / len(train_loader), loss.item()))

def metric_average(val, name):
    """
    metric average
    """
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()

def test():
    """
    test
    """
    model.eval()
    test_loss = 0.
    test_accuracy = 0.
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        output = model(data)
        # sum up batch loss
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=True)[1]
        test_accuracy += pred.eq(target.data.view_as(pred)).cpu().float().sum()
    # Horovod: use test_sampler to determine the number of examples in
    # this worker's partition.
    test_loss /= len(test_sampler)
    test_accuracy /= len(test_sampler)
    # Horovod: average metric values across workers.
    test_loss = metric_average(test_loss, 'avg_loss')
    test_accuracy = metric_average(test_accuracy, 'avg_accuracy')
    # Horovod: print output only on first rank.
    if hvd.rank() == 0:
        print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
            test_loss, 100. * test_accuracy))

def save():
    """
    save
    """
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # Save the model.
    # Horovod: save model only on first rank.
    if hvd.rank() == 0:
        torch.save(model.state_dict(), os.path.join(args.output_dir, 'model.pkl'))

if __name__ == '__main__':
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # Horovod: initialize library.
    hvd.init()
    torch.manual_seed(args.seed)
    if args.cuda:
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)
    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe.
    if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
        kwargs['multiprocessing_context'] = 'forkserver'

    # If there is no test set, use the training set for validation.
    if not os.path.exists(args.test_dir) or not os.listdir(args.test_dir):
        args.test_dir = args.train_dir
    # Convert the data from PIL.Image / numpy.ndarray to torch.FloatTensor.
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_set = MNIST(root=args.train_dir, train=True, transform=trans)
    test_set = MNIST(root=args.test_dir, train=False, transform=trans)
    # Horovod: use DistributedSampler to partition the training data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set, num_replicas=hvd.size(), rank=hvd.rank())
    # Horovod: use DistributedSampler to partition the test data.
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_set, num_replicas=hvd.size(), rank=hvd.rank())
    # Define the data loaders.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        sampler=train_sampler,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.test_batch_size,
        sampler=test_sampler,
        **kwargs)
    model = Net()
    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not args.use_adasum else 1
    if args.cuda:
        # Move model to GPU.
        model.cuda()
        # If using GPU Adasum allreduce, scale learning rate by local_size.
        if args.use_adasum and hvd.nccl_built():
            lr_scaler = hvd.local_size()
    # Horovod: scale learning rate by lr_scaler.
    optimizer = optim.SGD(model.parameters(), lr=args.lr * lr_scaler,
                          momentum=args.momentum)
    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression,
        op=hvd.Adasum if args.use_adasum else hvd.Average,
        gradient_predivide_factor=args.gradient_predivide_factor)
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test()
    save()
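Each Horovod worker runs a full copy of this script: `hvd.init()` connects the workers, `DistributedSampler` gives every rank its own slice of the dataset, and `hvd.DistributedOptimizer` averages gradients across ranks at every step. Outside the platform's built-in launcher, such a script is typically started with one process per GPU, for example `horovodrun -np 4 python train_mnist_hvd.py --train-dir ./train_data --test-dir ./test_data` (the filename is a placeholder).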
Inference Code
When a PyTorch model is published to the model repository, you must upload the custom code used to start the serving process, and the Python module specified as the main file must implement three functions: model loading (model_fn), request preprocessing (input_fn), and prediction postprocessing (output_fn). A local usage sketch of these three hooks follows the sample code.
Sample code:
Python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@license: Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved.
@desc: Example image-prediction algorithm
"""
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import base64
import json
from PIL import Image
from io import BytesIO
from torchvision import transforms

MODEL_FILE_NAME = 'model.pkl'  # model file name

def get_image_transform():
    """Build the transform used to preprocess input images.
    Returns:
        torchvision.transforms.Compose
    """
    trans = transforms.Compose([transforms.Resize((28, 28)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (1.0,))])
    return trans

def model_fn(model_dir):
    """Load the model.
    Args:
        model_dir: model directory; it holds the files produced under the
            output path chosen for the custom training job
    Returns:
        the loaded model object
    """
    class Net(nn.Module):
        """
        Net
        """
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            """
            forward
            """
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    model = Net()
    meta_info_path = "%s/%s" % (model_dir, MODEL_FILE_NAME)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(meta_info_path, map_location=device))
    model.to(device)
    model.eval()  # disable dropout for inference
    logging.info("device type: " + str(device))
    return model

def input_fn(request):
    """Format the input into what the model expects for prediction.
    Args:
        request: JSON body of the API request
    Returns:
        the input data needed for prediction, usually a tensor
    """
    instances = request['instances']
    transform_composes = get_image_transform()
    arr_tensor_data = []
    for instance in instances:
        decoded_data = base64.b64decode(instance['data'].encode("utf8"))
        byte_stream = BytesIO(decoded_data)
        roi_img = Image.open(byte_stream)
        target_data = transform_composes(roi_img)
        arr_tensor_data.append(target_data)
    tensor_data = torch.stack(arr_tensor_data, dim=0)
    return tensor_data

def output_fn(predict_result):
    """Format the output.
    Args:
        predict_result: prediction result
    Returns:
        the formatted prediction result; it must be JSON-serializable so that
        the API can return it
    """
    js_str = None
    if isinstance(predict_result, torch.Tensor):
        list_prediction = predict_result.detach().cpu().numpy().tolist()
        js_str = json.dumps(list_prediction)
    return js_str
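At serving time the framework calls these hooks in sequence: model_fn once at startup, then input_fn, the model's forward pass, and output_fn for each request. The snippet below is a minimal local sketch of that flow, appended to the module above; the model directory, image file, and request layout are illustrative assumptions rather than the platform's exact wiring.
Python
if __name__ == '__main__':
    # Hypothetical local check of the three hooks defined above.
    # Assumes ./output contains model.pkl and digit.png is a 28x28 grayscale image.
    model = model_fn('./output')                      # load the weights once
    with open('digit.png', 'rb') as f:
        request = {'instances': [{'data': base64.b64encode(f.read()).decode('utf8')}]}
    tensor = input_fn(request)                        # request JSON -> batched tensor
    tensor = tensor.to(next(model.parameters()).device)
    with torch.no_grad():
        prediction = model(tensor)                    # log-probabilities, shape [1, 10]
    print(output_fn(prediction))                      # JSON string returned to the caller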