Commit 7d06ba8c authored by Liuyuxinict

tijiao

parent 3b197eb1
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<!-- Maps the project root directory ($PROJECT_DIR$) to the Git version control system. -->
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="1de14600-5bec-46d2-972f-11687490a303" name="Default Changelist" comment="" /> <list default="true" id="1de14600-5bec-46d2-972f-11687490a303" name="Default Changelist" comment="">
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/data_loader.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" /> <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
...@@ -14,7 +19,11 @@ ...@@ -14,7 +19,11 @@
</list> </list>
</option> </option>
</component> </component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="ProjectId" id="1hp1mgftlnNubdit1AM27vnxPWs" /> <component name="ProjectId" id="1hp1mgftlnNubdit1AM27vnxPWs" />
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
<component name="ProjectViewState"> <component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" /> <option name="hideEmptyMiddlePackages" value="true" />
<option name="showExcludedFiles" value="true" /> <option name="showExcludedFiles" value="true" />
...@@ -22,6 +31,7 @@ ...@@ -22,6 +31,7 @@
</component> </component>
<component name="PropertiesComponent"> <component name="PropertiesComponent">
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" /> <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
<property name="last_opened_file_path" value="$PROJECT_DIR$" /> <property name="last_opened_file_path" value="$PROJECT_DIR$" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" /> <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component> </component>
...@@ -120,22 +130,22 @@ ...@@ -120,22 +130,22 @@
<screen x="0" y="0" width="1536" height="824" /> <screen x="0" y="0" width="1536" height="824" />
</state> </state>
<state x="549" y="171" key="FileChooserDialogImpl/0.0.1536.824@0.0.1536.824" timestamp="1658372379078" /> <state x="549" y="171" key="FileChooserDialogImpl/0.0.1536.824@0.0.1536.824" timestamp="1658372379078" />
<state width="1515" height="290" key="GridCell.Tab.0.bottom" timestamp="1658494944285"> <state width="1515" height="223" key="GridCell.Tab.0.bottom" timestamp="1658717893280">
<screen x="0" y="0" width="1536" height="824" /> <screen x="0" y="0" width="1536" height="824" />
</state> </state>
<state width="1515" height="290" key="GridCell.Tab.0.bottom/0.0.1536.824@0.0.1536.824" timestamp="1658494944285" /> <state width="1515" height="223" key="GridCell.Tab.0.bottom/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="290" key="GridCell.Tab.0.center" timestamp="1658494944285"> <state width="1515" height="223" key="GridCell.Tab.0.center" timestamp="1658717893280">
<screen x="0" y="0" width="1536" height="824" /> <screen x="0" y="0" width="1536" height="824" />
</state> </state>
<state width="1515" height="290" key="GridCell.Tab.0.center/0.0.1536.824@0.0.1536.824" timestamp="1658494944285" /> <state width="1515" height="223" key="GridCell.Tab.0.center/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="290" key="GridCell.Tab.0.left" timestamp="1658494944285"> <state width="1515" height="223" key="GridCell.Tab.0.left" timestamp="1658717893279">
<screen x="0" y="0" width="1536" height="824" /> <screen x="0" y="0" width="1536" height="824" />
</state> </state>
<state width="1515" height="290" key="GridCell.Tab.0.left/0.0.1536.824@0.0.1536.824" timestamp="1658494944285" /> <state width="1515" height="223" key="GridCell.Tab.0.left/0.0.1536.824@0.0.1536.824" timestamp="1658717893279" />
<state width="1515" height="290" key="GridCell.Tab.0.right" timestamp="1658494944285"> <state width="1515" height="223" key="GridCell.Tab.0.right" timestamp="1658717893280">
<screen x="0" y="0" width="1536" height="824" /> <screen x="0" y="0" width="1536" height="824" />
</state> </state>
<state width="1515" height="290" key="GridCell.Tab.0.right/0.0.1536.824@0.0.1536.824" timestamp="1658494944285" /> <state width="1515" height="223" key="GridCell.Tab.0.right/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="290" key="GridCell.Tab.1.bottom" timestamp="1658494944285"> <state width="1515" height="290" key="GridCell.Tab.1.bottom" timestamp="1658494944285">
<screen x="0" y="0" width="1536" height="824" /> <screen x="0" y="0" width="1536" height="824" />
</state> </state>
......
import torch
import PIL
from PIL import Image
import torch.utils.data as data
from torchvision import datasets, transforms
def My_loader(path):
    """Open the image at ``path`` and return it converted to RGB mode.

    Args:
        path: file-system path of the image to read.

    Returns:
        A PIL image in ``'RGB'`` mode.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager closes it deterministically once the pixel data has been
    # converted, instead of leaving that to garbage collection.
    with PIL.Image.open(path) as img:
        return img.convert('RGB')
class MyDataset(torch.utils.data.Dataset):
    """Dataset of (image, label) pairs described by a text list file.

    Every non-empty line of the list file has the form
    ``<image name> <integer label>`` with a single space as separator.
    """

    def __init__(self, txt_dir, image_path, transform=None, target_transform=None, loader=My_loader):
        """Read the list file and remember how samples are loaded.

        Args:
            txt_dir: path of the text file listing the samples.
            image_path: prefix that is concatenated (plain string ``+``)
                in front of every image name read from the list.
            transform: optional callable applied to each loaded image.
            target_transform: optional callable applied to each label.
            loader: callable that turns a file path into an image.
        """
        imgs = []
        # The context manager closes the list file once it has been parsed.
        with open(txt_dir, 'r') as data_txt:
            for line in data_txt:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                words = line.split(' ')
                imgs.append((words[0], int(words[1].strip())))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        # Honour the caller-supplied loader instead of always using My_loader.
        self.loader = loader
        self.image_path = image_path

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: position of the sample in the list file order.

        Returns:
            ``(img, label)``: the loaded image after ``transform`` (if any)
            and the label after ``target_transform`` (if any).
        """
        img_name, label = self.imgs[index]
        img = self.loader(self.image_path + img_name)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        # A single-label target is returned as a plain int; a multi-label
        # target would have to be a list converted to a torch tensor.
        return img, label
def load_data(image_path, train_dir, test_dir, batch_size, num_workers=0):
    """Build the training and testing datasets together with their loaders.

    Args:
        image_path: prefix prepended to every image name in the list files.
        train_dir: path of the training list file.
        test_dir: path of the testing list file.
        batch_size: batch size of the training loader; the testing loader
            uses half of it (never less than 1).
        num_workers: number of worker processes for both loaders
            (defaults to 0, the previously hard-coded value).

    Returns:
        ``(train_dataset, train_loader, test_dataset, test_loader)``.
    """
    # NOTE(review): the mean/std constants look like per-channel dataset
    # statistics; their origin is not visible here, so confirm before reuse.
    normalize = transforms.Normalize(mean=[0.5457954, 0.44430383, 0.34424934],
                                     std=[0.23273608, 0.24383051, 0.24237761])

    # Training pipeline: random augmentation, then resize, random crop,
    # tensor conversion and normalization.
    train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),  # default value is 0.5
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.126, saturation=0.5),
        transforms.Resize((550, 550)),
        transforms.RandomCrop(448),
        transforms.ToTensor(),
        normalize
    ])

    # Testing pipeline: deterministic resize and center crop only.
    test_transforms = transforms.Compose([
        transforms.Resize((550, 550)),
        transforms.CenterCrop((448, 448)),
        transforms.ToTensor(),
        normalize
    ])

    train_dataset = MyDataset(txt_dir=train_dir, image_path=image_path, transform=train_transforms)
    test_dataset = MyDataset(txt_dir=test_dir, image_path=image_path, transform=test_transforms)

    # batch_size // 2 is 0 when batch_size == 1, which DataLoader rejects;
    # clamp so that a batch size of 1 still yields a usable test loader.
    test_batch_size = max(1, batch_size // 2)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=test_batch_size,
                                              shuffle=False, num_workers=num_workers)
    return train_dataset, train_loader, test_dataset, test_loader
from __future__ import print_function from __future__ import print_function
import os
from PIL import Image from PIL import Image
import torch.utils.data as data import torch.utils.data as data
import os import os
import PIL import PIL
import argparse
from tqdm import tqdm from tqdm import tqdm
import torch.optim as optim import torch.optim as optim
from data_loader import load_data
from torch.optim import lr_scheduler from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import re import re
from utils import * from utils import *
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
EPOCH = 200 # number of times for each run-through
BATCH_SIZE = 2 # number of images for each epoch def parse_option():
LEARNING_RATE = 0.0001 # default learning rate parser = argparse.ArgumentParser('Progressive Region Enhancement Network(PRENet) for training and testing')
GPU_IN_USE = torch.cuda.is_available() # whether using GPU
DIR_TRAIN_IMAGES = r'E:\datasets\food101\meta_data\train_full.txt' parser.add_argument('--batchsize', default=2, type=int, help="batch size for single GPU")
#DIR_TRAIN_IMAGES = "/home/vipl/lyx/train_full.txt" parser.add_argument('--dataset', type=str, default='food101', help='food2k, food101, food500')
DIR_TEST_IMAGES = r'E:\datasets\food101\meta_data\test_full.txt' parser.add_argument('--image_path', type=str, default="E:/datasets/food101/images/", help='path to dataset')
#DIR_TEST_IMAGES = "/home/vipl/lyx/test_full.txt" parser.add_argument("--train_path", type=str, default="E:/datasets/food101/meta_data/train_full.txt", help='path to training list')
Image_path = r"E:/datasets/food101/images/" parser.add_argument("--test_path", type=str, default="E:/datasets/food101/meta_data/test_full.txt",
#Image_path = "/home/vipl/lizhuo/dataset_food/food101/images/" help='path to testing list')
#NUM_CATEGORIES = 500 parser.add_argument('--weight_path', default="E:/Pretrained_model/food2k_resnet50_0.0001.pth", help='path to the pretrained model')
NUM_CATEGORIES = 101 parser.add_argument('--use_checkpoint', action='store_true', default=False,
#WEIGHT_PATH= '/home/vipl/lyx/resnet50.pth' help="whether to use gradient checkpointing to save memory")
WEIGHT_PATH = r'E:/Pretrained_model/food2k_resnet50_0.0001.pth' parser.add_argument('--checkpoint', type=str, default=None,
help="the path to checkpoint")
checkpoint = '' parser.add_argument('--output_dir', default='output', type=str, metavar='PATH',
useJP = False #use Jigsaw Patches during PMG food2k_448_from2k_only_cengnei help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
usecheckpoint = False parser.add_argument("--learning_rate", default=1e-4, type=float,
checkpath = "./food2k_448_from2k_only_cengnei/model.pth" help="The initial learning rate for SGD.")
parser.add_argument("--epoch", default=200, type=int,
useAttn = True help="The number of epochs.")
args, unparsed = parser.parse_known_args()
normalize = transforms.Normalize(mean=[0.5457954,0.44430383,0.34424934], return args
std=[0.23273608,0.24383051,0.24237761])
# Augmentation and preprocessing steps applied to training images, in order.
_train_steps = [
    transforms.RandomHorizontalFlip(p=0.5),  # default value is 0.5
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.126, saturation=0.5),
    transforms.Resize((550, 550)),
    transforms.RandomCrop(448),
    transforms.ToTensor(),
    normalize,
]
train_transforms = transforms.Compose(_train_steps)

# Deterministic preprocessing steps applied to test images, in order.
_test_steps = [
    transforms.Resize((550, 550)),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    normalize,
]
test_transforms = transforms.Compose(_test_steps)
def My_loader(path):
    """Open the image at ``path`` and return it converted to RGB mode.

    Args:
        path: file-system path of the image to read.

    Returns:
        A PIL image in ``'RGB'`` mode.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager closes it deterministically once the pixel data has been
    # converted, instead of leaving that to garbage collection.
    with PIL.Image.open(path) as img:
        return img.convert('RGB')
class MyDataset(torch.utils.data.Dataset):
    """Dataset of (image, label) pairs described by a text list file.

    Every non-empty line of the list file has the form
    ``<image name> <integer label>`` with a single space as separator.
    Image names are resolved against the module-level ``Image_path`` prefix.
    """

    def __init__(self, txt_dir, transform=None, target_transform=None, loader=My_loader):
        """Read the list file and remember how samples are loaded.

        Args:
            txt_dir: path of the text file listing the samples.
            transform: optional callable applied to each loaded image.
            target_transform: optional callable applied to each label.
            loader: callable that turns a file path into an image.
        """
        imgs = []
        # The context manager closes the list file once it has been parsed.
        with open(txt_dir, 'r') as data_txt:
            for line in data_txt:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                words = line.split(' ')
                imgs.append((words[0], int(words[1].strip())))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        # Honour the caller-supplied loader instead of always using My_loader.
        self.loader = loader

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: position of the sample in the list file order.

        Returns:
            ``(img, label)``: the loaded image after ``transform`` (if any)
            and the label after ``target_transform`` (if any).
        """
        img_name, label = self.imgs[index]
        # Plain string concatenation with the module-level image prefix.
        img = self.loader(Image_path + img_name)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        # A single-label target is returned as a plain int; a multi-label
        # target would have to be a list converted to a torch tensor.
        return img, label
# Datasets built from the training / testing list files.
train_dataset = MyDataset(
    txt_dir=DIR_TRAIN_IMAGES,
    transform=train_transforms,
)
test_dataset = MyDataset(
    txt_dir=DIR_TEST_IMAGES,
    transform=test_transforms,
)

# Loaders: training data is shuffled, test data is not and uses half the batch size.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE // 2,
    shuffle=False,
    num_workers=0,
)
print('Data Preparation : Finished')
def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch, net,optimizer,exp_lr_scheduler): def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch, net,optimizer,exp_lr_scheduler):
exp_dir = store_name exp_dir = store_name
...@@ -140,10 +77,6 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch ...@@ -140,10 +77,6 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
# Step 1 # Step 1
optimizer.zero_grad() optimizer.zero_grad()
#inputs1 = jigsaw_generator(inputs, 8) #inputs1 = jigsaw_generator(inputs, 8)
inputs1 = None
if useJP:
_, _, _, _, output_1, _, _ = net(inputs1,False)
else:
_, _, _, _, output_1, _, _ = net(inputs, False) _, _, _, _, output_1, _, _ = net(inputs, False)
#print(output_1.shape) #print(output_1.shape)
loss1 = CELoss(output_1, targets) * 1 loss1 = CELoss(output_1, targets) * 1
...@@ -153,10 +86,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch ...@@ -153,10 +86,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
# Step 2 # Step 2
optimizer.zero_grad() optimizer.zero_grad()
#inputs2 = jigsaw_generator(inputs, 4) #inputs2 = jigsaw_generator(inputs, 4)
inputs2 = None
if useJP:
_, _, _, _, _, output_2, _, = net(inputs2,False)
else:
_, _, _, _, _, output_2, _, = net(inputs, False) _, _, _, _, _, output_2, _, = net(inputs, False)
#print(output_2.shape) #print(output_2.shape)
loss2 = CELoss(output_2, targets) * 1 loss2 = CELoss(output_2, targets) * 1
...@@ -166,10 +96,6 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch ...@@ -166,10 +96,6 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
# Step 3 # Step 3
optimizer.zero_grad() optimizer.zero_grad()
#inputs3 = jigsaw_generator(inputs, 2) #inputs3 = jigsaw_generator(inputs, 2)
inputs3 = None
if useJP:
_, _, _, _, _, _, output_3 = net(inputs3,False)
else:
_, _, _, _, _, _, output_3 = net(inputs, False) _, _, _, _, _, _, output_3 = net(inputs, False)
#print(output_3.shape) #print(output_3.shape)
loss3 = CELoss(output_3, targets) * 1 loss3 = CELoss(output_3, targets) * 1
...@@ -178,7 +104,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch ...@@ -178,7 +104,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
optimizer.zero_grad() optimizer.zero_grad()
x1, x2, x3, output_concat, _, _, _ = net(inputs,useAttn) x1, x2, x3, output_concat, _, _, _ = net(inputs,True)
concat_loss = CELoss(output_concat, targets) * 2 concat_loss = CELoss(output_concat, targets) * 2
...@@ -233,12 +159,25 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch ...@@ -233,12 +159,25 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
'Iteration %d, top1 = %.5f, top5 = %.5f, top1_combined = %.5f, top5_combined = %.5f, test_loss = %.6f\n' % ( 'Iteration %d, top1 = %.5f, top5 = %.5f, top1_combined = %.5f, top5_combined = %.5f, test_loss = %.6f\n' % (
epoch, val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss)) epoch, val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss))
net = load_model('resnet50_pmg',pretrain=False,require_grad=True,num_class=NUM_CATEGORIES) def main():
net.fc = nn.Linear(2048, 2000) args = parse_option()
state_dict = {} train_dataset, train_loader, test_dataset, test_loader = \
pretrained = torch.load(WEIGHT_PATH) load_data(image_path=args.image_path, train_dir=args.train_path, test_dir=args.test_path,batch_size=args.batchsize)
print('Data Preparation : Finished')
for k, v in net.state_dict().items(): if args.dataset == "food101":
NUM_CATEGORIES = 101
elif args.dataset == "food500":
NUM_CATEGORIES = 500
elif args.dataset == "food2k":
NUM_CATEGORIES = 2000
net = load_model('resnet50_pmg',pretrain=False,require_grad=True,num_class=NUM_CATEGORIES)
net.fc = nn.Linear(2048, 2000)
state_dict = {}
pretrained = torch.load(args.weight_path)
for k, v in net.state_dict().items():
if k[9:] in pretrained.keys() and "fc" not in k: if k[9:] in pretrained.keys() and "fc" not in k:
state_dict[k] = pretrained[k[9:]] state_dict[k] = pretrained[k[9:]]
elif "xx" in k and re.sub(r'xx[0-9]\.?',".", k[9:]) in pretrained.keys(): elif "xx" in k and re.sub(r'xx[0-9]\.?',".", k[9:]) in pretrained.keys():
...@@ -247,19 +186,19 @@ for k, v in net.state_dict().items(): ...@@ -247,19 +186,19 @@ for k, v in net.state_dict().items():
state_dict[k] = v state_dict[k] = v
print(k) print(k)
net.load_state_dict(state_dict) net.load_state_dict(state_dict)
net.fc = nn.Linear(2048,NUM_CATEGORIES) net.fc = nn.Linear(2048, NUM_CATEGORIES)
ignored_params = list(map(id, net.features.parameters())) ignored_params = list(map(id, net.features.parameters()))
new_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) new_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
optimizer = optim.SGD([ optimizer = optim.SGD([
{'params': net.features.parameters(), 'lr': LEARNING_RATE*0.1}, {'params': net.features.parameters(), 'lr': args.learning_rate*0.1},
{'params': new_params, 'lr': LEARNING_RATE} {'params': new_params, 'lr': args.learning_rate}
], ],
momentum=0.9, weight_decay=5e-4) momentum=0.9, weight_decay=5e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9) exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
for p in optimizer.param_groups: for p in optimizer.param_groups:
outputs = '' outputs = ''
for k, v in p.items(): for k, v in p.items():
if k is 'params': if k is 'params':
...@@ -268,23 +207,25 @@ for p in optimizer.param_groups: ...@@ -268,23 +207,25 @@ for p in optimizer.param_groups:
outputs += (k + ': ' + str(v).ljust(10) + ' ') outputs += (k + ': ' + str(v).ljust(10) + ' ')
print(outputs) print(outputs)
cudnn.benchmark = True cudnn.benchmark = True
net.cuda() net.cuda()
net = nn.DataParallel(net) net = nn.DataParallel(net)
if usecheckpoint: if args.use_checkpoint:
#net.load_state_dict(torch.load(checkpath)) #net.load_state_dict(torch.load(checkpath))
net.module.load_state_dict(torch.load(checkpath).module.state_dict()) net.module.load_state_dict(torch.load(args.checkpoint).module.state_dict())
print('load the checkpoint') print('load the checkpoint')
train(nb_epoch=200, # number of epoch train(nb_epoch=args.epoch, # number of epoch
trainloader=train_loader, trainloader=train_loader,
testloader=test_loader, testloader=test_loader,
batch_size=BATCH_SIZE, # batch size batch_size=args.batchsize, # batch size
store_name='food2k_448_from2k_only_cengnei', # folder for output store_name='model_448_from2k', # folder for output
start_epoch=0, start_epoch=0,
net=net, net=net,
optimizer = optimizer, optimizer = optimizer,
exp_lr_scheduler=exp_lr_scheduler) # the start epoch number when you resume the training exp_lr_scheduler=exp_lr_scheduler) # the start epoch number when you resume the training
if __name__ == "__main__":
main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment