Commit 6cb1b7ad authored by lishen

[fix]

parent 7a41852c
@@ -12,7 +12,8 @@ import torch.backends.cudnn as cudnn
 import re
 from utils import *
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
 def parse_option():
     parser = argparse.ArgumentParser('Progressive Region Enhancement Network(PRENet) for training and testing')
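Unpinning CUDA_VISIBLE_DEVICES here is consistent with the multi-GPU device_ids change further down. As a reminder, the variable only takes effect when set before CUDA is initialized, and PyTorch renumbers the visible devices from zero; a minimal sketch (the "2,3" value is just an example):

```python
import os

# Must be set before the first CUDA call; later assignments are ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"   # expose physical GPUs 2 and 3

import torch

print(torch.cuda.device_count())    # 2 on a machine with at least 4 GPUs
print(torch.cuda.current_device())  # 0, i.e. physical GPU 2 renumbered
```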
@@ -39,20 +40,19 @@ def parse_option():
     args, unparsed = parser.parse_known_args()
     return args

-
-def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch, net,optimizer,exp_lr_scheduler):
+def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch, net, optimizer, exp_lr_scheduler):
     exp_dir = store_name
     try:
         os.stat(exp_dir)
     except:
         os.makedirs(exp_dir)
     CELoss = nn.CrossEntropyLoss()
     KLLoss = nn.KLDivLoss(reduction="batchmean")
     max_val_acc = 0
-    #val_acc, val5_acc, _, _, val_loss = test(net, CELoss, batch_size, testloader)
+    # val_acc, val5_acc, _, _, val_loss = test(net, CELoss, batch_size, testloader)
     for epoch in range(start_epoch, nb_epoch):
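The try/os.stat/except probe above predates os.makedirs(exist_ok=True); an equivalent one-liner, shown only as a sketch with a hypothetical directory name:

```python
import os

exp_dir = "model_448_from2k"          # hypothetical store_name
os.makedirs(exp_dir, exist_ok=True)   # no error if the directory exists
```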
@@ -78,42 +78,40 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
             # Step 1
             optimizer.zero_grad()
-            #inputs1 = jigsaw_generator(inputs, 8)
+            # inputs1 = jigsaw_generator(inputs, 8)
             _, _, _, _, output_1, _, _ = net(inputs, False)
-            #print(output_1.shape)
+            # print(output_1.shape)
             loss1 = CELoss(output_1, targets) * 1
             loss1.backward()
             optimizer.step()

             # Step 2
             optimizer.zero_grad()
-            #inputs2 = jigsaw_generator(inputs, 4)
+            # inputs2 = jigsaw_generator(inputs, 4)
             _, _, _, _, _, output_2, _, = net(inputs, False)
-            #print(output_2.shape)
+            # print(output_2.shape)
             loss2 = CELoss(output_2, targets) * 1
             loss2.backward()
             optimizer.step()

             # Step 3
             optimizer.zero_grad()
-            #inputs3 = jigsaw_generator(inputs, 2)
+            # inputs3 = jigsaw_generator(inputs, 2)
             _, _, _, _, _, _, output_3 = net(inputs, False)
-            #print(output_3.shape)
+            # print(output_3.shape)
             loss3 = CELoss(output_3, targets) * 1
             loss3.backward()
             optimizer.step()

             optimizer.zero_grad()
-            x1, x2, x3, output_concat, _, _, _ = net(inputs,True)
+            x1, x2, x3, output_concat, _, _, _ = net(inputs, True)
             concat_loss = CELoss(output_concat, targets) * 2
-            #loss4 = -KLLoss(F.softmax(x1, dim=1), F.softmax(x2, dim=1)) / batch_size
-            #loss5 = -KLLoss(F.softmax(x1, dim=1), F.softmax(x3, dim=1)) / batch_size
+            # loss4 = -KLLoss(F.softmax(x1, dim=1), F.softmax(x2, dim=1)) / batch_size
+            # loss5 = -KLLoss(F.softmax(x1, dim=1), F.softmax(x3, dim=1)) / batch_size
             loss6 = -KLLoss(F.softmax(x2, dim=1), F.softmax(x1, dim=1))
-            #loss7 = -KLLoss(F.softmax(x2, dim=1), F.softmax(x3, dim=1)) / batch_size
+            # loss7 = -KLLoss(F.softmax(x2, dim=1), F.softmax(x3, dim=1)) / batch_size
             loss8 = -KLLoss(F.softmax(x3, dim=1), F.softmax(x1, dim=1))
             loss9 = -KLLoss(F.softmax(x3, dim=1), F.softmax(x2, dim=1))
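Steps 1-3 above take one optimizer step per granularity branch before the combined pass, and the negative KL terms push the branch distributions apart. A self-contained sketch of that pattern with a toy two-branch model (PRENet itself returns seven outputs; also note that nn.KLDivLoss expects log-probabilities as its first argument, whereas the script passes plain softmax outputs, a detail this commit leaves unchanged):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyNet(nn.Module):
    """Stand-in for PRENet: two classification branches over one input."""
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Linear(32, 10)
        self.branch2 = nn.Linear(32, 10)

    def forward(self, x):
        return self.branch1(x), self.branch2(x)

net = ToyNet()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
CELoss = nn.CrossEntropyLoss()
KLLoss = nn.KLDivLoss(reduction="batchmean")

inputs = torch.randn(8, 32)
targets = torch.randint(0, 10, (8,))

# One zero_grad/forward/backward/step per branch, as in Steps 1-3.
for pick in (0, 1):
    optimizer.zero_grad()
    CELoss(net(inputs)[pick], targets).backward()
    optimizer.step()

# Combined pass: CE on one branch plus a negative KL diversity term.
optimizer.zero_grad()
x1, x2 = net(inputs)
loss = CELoss(x1, targets) - KLLoss(F.log_softmax(x2, dim=1), F.softmax(x1, dim=1))
loss.backward()
optimizer.step()
```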
@@ -137,9 +135,9 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
             if batch_idx % 10 == 0:
                 print(
                     'Step: %d | Loss1: %.3f | Loss2: %.5f | Loss3: %.5f | Loss_concat: %.5f | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                         batch_idx, train_loss1 / (batch_idx + 1), train_loss2 / (batch_idx + 1),
                         train_loss3 / (batch_idx + 1), train_loss4 / (batch_idx + 1), train_loss / (batch_idx + 1),
                         100. * float(correct) / total, correct, total))
             batch_idx += 1

         exp_lr_scheduler.step()
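exp_lr_scheduler is stepped once per epoch, after the batch loop. A minimal sketch of that pattern; the ExponentialLR choice is an assumption, since the scheduler's construction is outside this diff:

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

net = nn.Linear(8, 2)  # stand-in model
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
exp_lr_scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)  # assumed type

for epoch in range(3):
    # ... per-batch training steps would run here ...
    exp_lr_scheduler.step()                # once per epoch, after the batches
    print(exp_lr_scheduler.get_last_lr())  # [0.009], [0.0081], [0.00729]
```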
@@ -149,10 +147,10 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
         with open(exp_dir + '/results_train.txt', 'a') as file:
             file.write(
                 'Iteration %d | train_acc = %.5f | train_loss = %.5f | Loss1: %.3f | Loss2: %.5f | Loss3: %.5f | Loss_concat: %.5f |\n' % (
                     epoch, train_acc, train_loss, train_loss1 / (idx + 1), train_loss2 / (idx + 1), train_loss3 / (idx + 1),
                     train_loss4 / (idx + 1)))

-        val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, CELoss, batch_size, testloader,True)
+        val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, CELoss, batch_size, testloader, True)
         if val_acc > max_val_acc:
             max_val_acc = val_acc
             torch.save(net, './' + store_name + '/model.pth')
@@ -161,10 +159,11 @@
                 'Iteration %d, top1 = %.5f, top5 = %.5f, top1_combined = %.5f, top5_combined = %.5f, test_loss = %.6f\n' % (
                     epoch, val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss))

+
 def main():
     args = parse_option()
     train_dataset, train_loader, test_dataset, test_loader = \
-        load_data(image_path=args.image_path, train_dir=args.train_path, test_dir=args.test_path,batch_size=args.batchsize)
+        load_data(image_path=args.image_path, train_dir=args.train_path, test_dir=args.test_path, batch_size=args.batchsize)
     print('Data Preparation : Finished')
     if args.dataset == "food101":
         NUM_CATEGORIES = 101
@@ -175,8 +174,7 @@ def main():
     elif args.dataset == "jkyy":
         NUM_CATEGORIES = 1788

-
-    net = load_model('resnet50',pretrain=False,require_grad=True,num_class=NUM_CATEGORIES)
+    net = load_model('resnet50', pretrain=False, require_grad=True, num_class=NUM_CATEGORIES)
     net.fc = nn.Linear(2048, 2000)
     state_dict = {}
     pretrained = torch.load(args.weight_path)
@@ -184,8 +182,8 @@ def main():
     for k, v in net.state_dict().items():
         if k[9:] in pretrained.keys() and "fc" not in k:
             state_dict[k] = pretrained[k[9:]]
-        elif "xx" in k and re.sub(r'xx[0-9]\.?',".", k[9:]) in pretrained.keys():
-            state_dict[k] = pretrained[re.sub(r'xx[0-9]\.?',".", k[9:])]
+        elif "xx" in k and re.sub(r'xx[0-9]\.?', ".", k[9:]) in pretrained.keys():
+            state_dict[k] = pretrained[re.sub(r'xx[0-9]\.?', ".", k[9:])]
         else:
             state_dict[k] = v
             print(k)
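The elif branch reuses a pretrained weight for duplicated blocks whose keys carry an "xx&lt;digit&gt;" suffix: k[9:] drops the 9-character "features." prefix, and the re.sub collapses the suffix. A toy illustration with hypothetical key names (PRENet's real keys may differ):

```python
import re

pretrained_keys = {"layer3.conv1.weight"}    # toy pretrained checkpoint key
k = "features.layer3xx0.conv1.weight"        # hypothetical duplicated-block key

stripped = k[9:]                                 # drop "features." prefix
remapped = re.sub(r'xx[0-9]\.?', ".", stripped)  # collapse "xx0." into "."

print(remapped)                     # layer3.conv1.weight
print(remapped in pretrained_keys)  # True -> the pretrained weight is reused
```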
@@ -196,7 +194,7 @@ def main():
     ignored_params = list(map(id, net.features.parameters()))
     new_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
     optimizer = optim.SGD([
-        {'params': net.features.parameters(), 'lr': args.learning_rate*0.1},
+        {'params': net.features.parameters(), 'lr': args.learning_rate * 0.1},
         {'params': new_params, 'lr': args.learning_rate}
     ],
         momentum=0.9, weight_decay=5e-4)
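The two parameter groups give the pretrained backbone a 10x smaller learning rate than the newly added layers. A self-contained sketch of the same construction, with toy modules standing in for net.features and the new parameters:

```python
import torch.nn as nn
import torch.optim as optim

backbone = nn.Linear(16, 16)   # stands in for net.features
head = nn.Linear(16, 10)       # stands in for the freshly added parameters
learning_rate = 0.002          # hypothetical args.learning_rate

optimizer = optim.SGD(
    [
        {'params': backbone.parameters(), 'lr': learning_rate * 0.1},  # 10x smaller
        {'params': head.parameters(), 'lr': learning_rate},
    ],
    momentum=0.9, weight_decay=5e-4)

print([g['lr'] for g in optimizer.param_groups])  # backbone LR, then head LR
```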
@@ -213,13 +211,13 @@ def main():
     cudnn.benchmark = True
     net.cuda()
-    device_ids = [0]
+    device_ids = [0, 1, 2, 3, 4]
     # net = nn.DataParallel(net).to(device_ids)
     net = nn.DataParallel(net, device_ids=device_ids)
     # optimizer = nn.DataParallel(optimizer, device_ids=device_ids)

     if args.use_checkpoint:
-        #net.load_state_dict(torch.load(checkpath))
+        # net.load_state_dict(torch.load(checkpath))
         model = torch.load(args.checkpoint).module.state_dict()
         net.module.load_state_dict(torch.load(args.checkpoint).module.state_dict())
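The device_ids bump to five GPUs pairs with the unpinned CUDA_VISIBLE_DEVICES at the top. Since torch.save(net, ...) above pickles the whole DataParallel wrapper, loading goes through .module, whose state-dict keys lack the "module." prefix the wrapper adds. A small sketch:

```python
import torch
import torch.nn as nn

wrapped = nn.DataParallel(nn.Linear(4, 2))  # acts as a passthrough without GPUs

print(list(wrapped.state_dict().keys()))         # ['module.weight', 'module.bias']
print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']

# Restoring into a fresh wrapper, mirroring the checkpoint branch above:
fresh = nn.DataParallel(nn.Linear(4, 2))
fresh.module.load_state_dict(wrapped.module.state_dict())
```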
@@ -228,19 +226,19 @@ def main():
     if args.test:
         val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, nn.CrossEntropyLoss(), args.batchsize, test_loader, True)
         print('Accuracy of the network on the val images: top1 = %.5f, top5 = %.5f, top1_combined = %.5f, top5_combined = %.5f, test_loss = %.6f\n' % (
             val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss))
         return

-    train(nb_epoch=args.epoch,             # number of epoch
-          trainloader=train_loader,
-          testloader=test_loader,
-          batch_size=args.batchsize,       # batch size
-          store_name='model_448_from2k',   # folder for output
-          start_epoch=0,
-          net=net,
-          optimizer = optimizer,
-          exp_lr_scheduler=exp_lr_scheduler)  # the start epoch number when you resume the training
+    train(nb_epoch=args.epoch,             # number of epoch
+          trainloader=train_loader,
+          testloader=test_loader,
+          batch_size=args.batchsize,       # batch size
+          store_name='model_448_from2k',   # folder for output
+          start_epoch=0,
+          net=net,
+          optimizer=optimizer,
+          exp_lr_scheduler=exp_lr_scheduler)  # the start epoch number when you resume the training

 if __name__ == "__main__":
     main()
\ No newline at end of file