Commit 0b240082 authored by lishen's avatar lishen

[fix]

parent 20813195
......@@ -15,6 +15,7 @@ class MyDataset(torch.utils.data.Dataset):
for line in data_txt:
line = line.strip()
words = line.split(' ')
print(words)
imgs.append((words[0], int(words[1].strip())))
self.imgs = imgs
self.transform = transform
......
......@@ -16,6 +16,11 @@ def load_model(model_name, pretrain=True, require_grad=True, num_class=1000, pre
#for param in net.parameters():
#param.requires_grad = require_grad
net = PRENet(net, 512, num_class)
elif model_name == 'resnet152':
net = resnet152(pretrained=pretrain, path=pretrained_model)
#for param in net.parameters():
#param.requires_grad = require_grad
net = PRENet(net, 512, num_class)
return net
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment