Commit 6cb1b7ad authored by lishen

[fix]

parent 7a41852c
......@@ -12,7 +12,8 @@ import torch.backends.cudnn as cudnn
import re
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def parse_option():
parser = argparse.ArgumentParser('Progressive Region Enhancement Network(PRENet) for training and testing')
......@@ -39,20 +40,19 @@ def parse_option():
args, unparsed = parser.parse_known_args()
return args
def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch, net,optimizer,exp_lr_scheduler):
def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch, net, optimizer, exp_lr_scheduler):
exp_dir = store_name
try:
os.stat(exp_dir)
except:
os.makedirs(exp_dir)
CELoss = nn.CrossEntropyLoss()
KLLoss = nn.KLDivLoss(reduction="batchmean")
max_val_acc = 0
#val_acc, val5_acc, _, _, val_loss = test(net, CELoss, batch_size, testloader)
# val_acc, val5_acc, _, _, val_loss = test(net, CELoss, batch_size, testloader)
for epoch in range(start_epoch, nb_epoch):
......@@ -78,42 +78,40 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
# Step 1
optimizer.zero_grad()
#inputs1 = jigsaw_generator(inputs, 8)
# inputs1 = jigsaw_generator(inputs, 8)
_, _, _, _, output_1, _, _ = net(inputs, False)
#print(output_1.shape)
# print(output_1.shape)
loss1 = CELoss(output_1, targets) * 1
loss1.backward()
optimizer.step()
# Step 2
optimizer.zero_grad()
#inputs2 = jigsaw_generator(inputs, 4)
# inputs2 = jigsaw_generator(inputs, 4)
_, _, _, _, _, output_2, _, = net(inputs, False)
#print(output_2.shape)
# print(output_2.shape)
loss2 = CELoss(output_2, targets) * 1
loss2.backward()
optimizer.step()
# Step 3
optimizer.zero_grad()
#inputs3 = jigsaw_generator(inputs, 2)
# inputs3 = jigsaw_generator(inputs, 2)
_, _, _, _, _, _, output_3 = net(inputs, False)
#print(output_3.shape)
# print(output_3.shape)
loss3 = CELoss(output_3, targets) * 1
loss3.backward()
optimizer.step()
optimizer.zero_grad()
x1, x2, x3, output_concat, _, _, _ = net(inputs,True)
x1, x2, x3, output_concat, _, _, _ = net(inputs, True)
concat_loss = CELoss(output_concat, targets) * 2
#loss4 = -KLLoss(F.softmax(x1, dim=1), F.softmax(x2, dim=1)) / batch_size
#loss5 = -KLLoss(F.softmax(x1, dim=1), F.softmax(x3, dim=1)) / batch_size
# loss4 = -KLLoss(F.softmax(x1, dim=1), F.softmax(x2, dim=1)) / batch_size
# loss5 = -KLLoss(F.softmax(x1, dim=1), F.softmax(x3, dim=1)) / batch_size
loss6 = -KLLoss(F.softmax(x2, dim=1), F.softmax(x1, dim=1))
#loss7 = -KLLoss(F.softmax(x2, dim=1), F.softmax(x3, dim=1)) / batch_size
# loss7 = -KLLoss(F.softmax(x2, dim=1), F.softmax(x3, dim=1)) / batch_size
loss8 = -KLLoss(F.softmax(x3, dim=1), F.softmax(x1, dim=1))
loss9 = -KLLoss(F.softmax(x3, dim=1), F.softmax(x2, dim=1))
......@@ -137,9 +135,9 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
if batch_idx % 10 == 0:
print(
'Step: %d | Loss1: %.3f | Loss2: %.5f | Loss3: %.5f | Loss_concat: %.5f | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
batch_idx, train_loss1 / (batch_idx + 1), train_loss2 / (batch_idx + 1),
train_loss3 / (batch_idx + 1), train_loss4 / (batch_idx + 1), train_loss / (batch_idx + 1),
100. * float(correct) / total, correct, total))
batch_idx, train_loss1 / (batch_idx + 1), train_loss2 / (batch_idx + 1),
train_loss3 / (batch_idx + 1), train_loss4 / (batch_idx + 1), train_loss / (batch_idx + 1),
100. * float(correct) / total, correct, total))
batch_idx += 1
exp_lr_scheduler.step()
......@@ -149,10 +147,10 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
with open(exp_dir + '/results_train.txt', 'a') as file:
file.write(
'Iteration %d | train_acc = %.5f | train_loss = %.5f | Loss1: %.3f | Loss2: %.5f | Loss3: %.5f | Loss_concat: %.5f |\n' % (
epoch, train_acc, train_loss, train_loss1 / (idx + 1), train_loss2 / (idx + 1), train_loss3 / (idx + 1),
train_loss4 / (idx + 1)))
epoch, train_acc, train_loss, train_loss1 / (idx + 1), train_loss2 / (idx + 1), train_loss3 / (idx + 1),
train_loss4 / (idx + 1)))
val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, CELoss, batch_size, testloader,True)
val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, CELoss, batch_size, testloader, True)
if val_acc > max_val_acc:
max_val_acc = val_acc
torch.save(net, './' + store_name + '/model.pth')
......@@ -161,10 +159,11 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
'Iteration %d, top1 = %.5f, top5 = %.5f, top1_combined = %.5f, top5_combined = %.5f, test_loss = %.6f\n' % (
epoch, val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss))
def main():
args = parse_option()
train_dataset, train_loader, test_dataset, test_loader = \
load_data(image_path=args.image_path, train_dir=args.train_path, test_dir=args.test_path,batch_size=args.batchsize)
load_data(image_path=args.image_path, train_dir=args.train_path, test_dir=args.test_path, batch_size=args.batchsize)
print('Data Preparation : Finished')
if args.dataset == "food101":
NUM_CATEGORIES = 101
......@@ -175,8 +174,7 @@ def main():
elif args.dataset == "jkyy":
NUM_CATEGORIES = 1788
net = load_model('resnet50',pretrain=False,require_grad=True,num_class=NUM_CATEGORIES)
net = load_model('resnet50', pretrain=False, require_grad=True, num_class=NUM_CATEGORIES)
net.fc = nn.Linear(2048, 2000)
state_dict = {}
pretrained = torch.load(args.weight_path)
......@@ -184,8 +182,8 @@ def main():
for k, v in net.state_dict().items():
if k[9:] in pretrained.keys() and "fc" not in k:
state_dict[k] = pretrained[k[9:]]
elif "xx" in k and re.sub(r'xx[0-9]\.?',".", k[9:]) in pretrained.keys():
state_dict[k] = pretrained[re.sub(r'xx[0-9]\.?',".", k[9:])]
elif "xx" in k and re.sub(r'xx[0-9]\.?', ".", k[9:]) in pretrained.keys():
state_dict[k] = pretrained[re.sub(r'xx[0-9]\.?', ".", k[9:])]
else:
state_dict[k] = v
print(k)
......@@ -196,7 +194,7 @@ def main():
ignored_params = list(map(id, net.features.parameters()))
new_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
optimizer = optim.SGD([
{'params': net.features.parameters(), 'lr': args.learning_rate*0.1},
{'params': net.features.parameters(), 'lr': args.learning_rate * 0.1},
{'params': new_params, 'lr': args.learning_rate}
],
momentum=0.9, weight_decay=5e-4)
......@@ -213,13 +211,13 @@ def main():
cudnn.benchmark = True
net.cuda()
device_ids = [0]
device_ids = [0, 1, 2, 3, 4]
# net = nn.DataParallel(net).to(device_ids)
net = nn.DataParallel(net, device_ids=device_ids)
# optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
if args.use_checkpoint:
#net.load_state_dict(torch.load(checkpath))
# net.load_state_dict(torch.load(checkpath))
model = torch.load(args.checkpoint).module.state_dict()
net.module.load_state_dict(torch.load(args.checkpoint).module.state_dict())
......@@ -228,19 +226,19 @@ def main():
if args.test:
val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, nn.CrossEntropyLoss(), args.batchsize, test_loader, True)
print('Accuracy of the network on the val images: top1 = %.5f, top5 = %.5f, top1_combined = %.5f, top5_combined = %.5f, test_loss = %.6f\n' % (
val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss))
val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss))
return
train(nb_epoch=args.epoch, # number of epoch
trainloader=train_loader,
testloader=test_loader,
batch_size=args.batchsize, # batch size
store_name='model_448_from2k', # folder for output
start_epoch=0,
net=net,
optimizer=optimizer,
exp_lr_scheduler=exp_lr_scheduler) # the start epoch number when you resume the training
train(nb_epoch=args.epoch, # number of epoch
trainloader=train_loader,
testloader=test_loader,
batch_size=args.batchsize, # batch size
store_name='model_448_from2k', # folder for output
start_epoch=0,
net=net,
optimizer = optimizer,
exp_lr_scheduler=exp_lr_scheduler) # the start epoch number when you resume the training
if __name__ == "__main__":
main()
\ No newline at end of file
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment