Commit 7d06ba8c authored by Liuyuxinict

tijiao ("submit")

parent 3b197eb1
.idea/vcs.xml (new file)
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
\ No newline at end of file
.idea/workspace.xml
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="1de14600-5bec-46d2-972f-11687490a303" name="Default Changelist" comment="" />
<list default="true" id="1de14600-5bec-46d2-972f-11687490a303" name="Default Changelist" comment="">
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/data_loader.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
@@ -14,7 +19,11 @@
</list>
</option>
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="ProjectId" id="1hp1mgftlnNubdit1AM27vnxPWs" />
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showExcludedFiles" value="true" />
@@ -22,6 +31,7 @@
</component>
<component name="PropertiesComponent">
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
<property name="last_opened_file_path" value="$PROJECT_DIR$" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
@@ -120,22 +130,22 @@
<screen x="0" y="0" width="1536" height="824" />
</state>
<state x="549" y="171" key="FileChooserDialogImpl/0.0.1536.824@0.0.1536.824" timestamp="1658372379078" />
<state width="1515" height="290" key="GridCell.Tab.0.bottom" timestamp="1658494944285">
<state width="1515" height="223" key="GridCell.Tab.0.bottom" timestamp="1658717893280">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state width="1515" height="290" key="GridCell.Tab.0.bottom/0.0.1536.824@0.0.1536.824" timestamp="1658494944285" />
<state width="1515" height="290" key="GridCell.Tab.0.center" timestamp="1658494944285">
<state width="1515" height="223" key="GridCell.Tab.0.bottom/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="223" key="GridCell.Tab.0.center" timestamp="1658717893280">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state width="1515" height="290" key="GridCell.Tab.0.center/0.0.1536.824@0.0.1536.824" timestamp="1658494944285" />
<state width="1515" height="290" key="GridCell.Tab.0.left" timestamp="1658494944285">
<state width="1515" height="223" key="GridCell.Tab.0.center/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="223" key="GridCell.Tab.0.left" timestamp="1658717893279">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state width="1515" height="290" key="GridCell.Tab.0.left/0.0.1536.824@0.0.1536.824" timestamp="1658494944285" />
<state width="1515" height="290" key="GridCell.Tab.0.right" timestamp="1658494944285">
<state width="1515" height="223" key="GridCell.Tab.0.left/0.0.1536.824@0.0.1536.824" timestamp="1658717893279" />
<state width="1515" height="223" key="GridCell.Tab.0.right" timestamp="1658717893280">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state width="1515" height="290" key="GridCell.Tab.0.right/0.0.1536.824@0.0.1536.824" timestamp="1658494944285" />
<state width="1515" height="223" key="GridCell.Tab.0.right/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="290" key="GridCell.Tab.1.bottom" timestamp="1658494944285">
<screen x="0" y="0" width="1536" height="824" />
</state>
......

data_loader.py (new file)
import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms


def My_loader(path):
    return Image.open(path).convert('RGB')


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, txt_dir, image_path, transform=None, target_transform=None, loader=My_loader):
        # Each line of the list file is "<relative_image_path> <integer_label>".
        imgs = []
        with open(txt_dir, 'r') as data_txt:
            for line in data_txt:
                words = line.strip().split(' ')
                imgs.append((words[0], int(words[1].strip())))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.image_path = image_path

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_name, label = self.imgs[index]
        img = self.loader(self.image_path + img_name)
        if self.transform is not None:
            img = self.transform(img)
        # For a single-label dataset the target can stay an int; for a
        # multi-label dataset, convert the label list to a torch.Tensor.
        return img, label
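The parser above expects each line of the list file to pair a relative image path with an integer class label, separated by a single space. A hypothetical excerpt of what train_full.txt might look like (file names invented for illustration):

apple_pie/1005649.jpg 0
apple_pie/1014775.jpg 0
baby_back_ribs/2032.jpg 1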
def load_data(image_path, train_dir, test_dir, batch_size):
    # Per-channel normalization statistics for the food images.
    normalize = transforms.Normalize(mean=[0.5457954, 0.44430383, 0.34424934],
                                     std=[0.23273608, 0.24383051, 0.24237761])
    # Training transforms: augmentation plus a random 448x448 crop.
    train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.126, saturation=0.5),
        transforms.Resize((550, 550)),
        transforms.RandomCrop(448),
        transforms.ToTensor(),
        normalize
    ])
    # Test transforms: deterministic 448x448 center crop.
    test_transforms = transforms.Compose([
        transforms.Resize((550, 550)),
        transforms.CenterCrop((448, 448)),
        transforms.ToTensor(),
        normalize
    ])
    train_dataset = MyDataset(txt_dir=train_dir, image_path=image_path, transform=train_transforms)
    test_dataset = MyDataset(txt_dir=test_dir, image_path=image_path, transform=test_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size // 2,
                                              shuffle=False, num_workers=0)
    return train_dataset, train_loader, test_dataset, test_loader
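A minimal usage sketch for load_data; the paths below are illustrative placeholders, not values from the commit:

from data_loader import load_data

train_set, train_loader, test_set, test_loader = load_data(
    image_path="/data/food101/images/",                   # hypothetical
    train_dir="/data/food101/meta_data/train_full.txt",   # hypothetical
    test_dir="/data/food101/meta_data/test_full.txt",     # hypothetical
    batch_size=8)
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([8, 3, 448, 448])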
train.py

from __future__ import print_function
import os
from PIL import Image
import torch.utils.data as data
import PIL
import argparse
from tqdm import tqdm
import torch.optim as optim
+from data_loader import load_data
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import re
from utils import *

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-EPOCH = 200                 # number of passes over the training set
-BATCH_SIZE = 2              # number of images per batch
-LEARNING_RATE = 0.0001      # default learning rate
-GPU_IN_USE = torch.cuda.is_available()  # whether a GPU is available
-DIR_TRAIN_IMAGES = r'E:\datasets\food101\meta_data\train_full.txt'
-#DIR_TRAIN_IMAGES = "/home/vipl/lyx/train_full.txt"
-DIR_TEST_IMAGES = r'E:\datasets\food101\meta_data\test_full.txt'
-#DIR_TEST_IMAGES = "/home/vipl/lyx/test_full.txt"
-Image_path = r"E:/datasets/food101/images/"
-#Image_path = "/home/vipl/lizhuo/dataset_food/food101/images/"
-#NUM_CATEGORIES = 500
-NUM_CATEGORIES = 101
-#WEIGHT_PATH= '/home/vipl/lyx/resnet50.pth'
-WEIGHT_PATH = r'E:/Pretrained_model/food2k_resnet50_0.0001.pth'
-checkpoint = ''
-useJP = False  # use Jigsaw Patches during PMG training
-usecheckpoint = False
-checkpath = "./food2k_448_from2k_only_cengnei/model.pth"
-useAttn = True
-normalize = transforms.Normalize(mean=[0.5457954, 0.44430383, 0.34424934],
-                                 std=[0.23273608, 0.24383051, 0.24237761])
-train_transforms = transforms.Compose([
-    transforms.RandomHorizontalFlip(p=0.5),
-    transforms.RandomRotation(degrees=15),
-    transforms.ColorJitter(brightness=0.126, saturation=0.5),
-    transforms.Resize((550, 550)),
-    transforms.RandomCrop(448),
-    transforms.ToTensor(),
-    normalize
-])
-# transforms of test dataset
-test_transforms = transforms.Compose([
-    transforms.Resize((550, 550)),
-    transforms.CenterCrop((448, 448)),
-    transforms.ToTensor(),
-    normalize
-])
-def My_loader(path):
-    return PIL.Image.open(path).convert('RGB')
-class MyDataset(torch.utils.data.Dataset):
-    def __init__(self, txt_dir, transform=None, target_transform=None, loader=My_loader):
-        data_txt = open(txt_dir, 'r')
-        imgs = []
-        for line in data_txt:
-            line = line.strip()
-            words = line.split(' ')
-            imgs.append((words[0], int(words[1].strip())))
-        self.imgs = imgs
-        self.transform = transform
-        self.target_transform = target_transform
-        self.loader = My_loader
-    def __len__(self):
-        return len(self.imgs)
-    def __getitem__(self, index):
-        img_name, label = self.imgs[index]
-        img = self.loader(Image_path + img_name)
-        if self.transform is not None:
-            img = self.transform(img)
-        return img, label
-train_dataset = MyDataset(txt_dir=DIR_TRAIN_IMAGES, transform=train_transforms)
-test_dataset = MyDataset(txt_dir=DIR_TEST_IMAGES, transform=test_transforms)
-train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
-test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE//2, shuffle=False, num_workers=0)
-print('Data Preparation : Finished')
+def parse_option():
+    parser = argparse.ArgumentParser('Progressive Region Enhancement Network (PRENet) for training and testing')
+    parser.add_argument('--batchsize', default=2, type=int, help="batch size for a single GPU")
+    parser.add_argument('--dataset', type=str, default='food101', help='food2k, food101, food500')
+    parser.add_argument('--image_path', type=str, default="E:/datasets/food101/images/", help='path to dataset')
+    parser.add_argument("--train_path", type=str, default="E:/datasets/food101/meta_data/train_full.txt", help='path to training list')
+    parser.add_argument("--test_path", type=str, default="E:/datasets/food101/meta_data/test_full.txt",
+                        help='path to testing list')
+    parser.add_argument('--weight_path', default="E:/Pretrained_model/food2k_resnet50_0.0001.pth", help='path to the pretrained model')
+    parser.add_argument('--use_checkpoint', action='store_true', default=False,
+                        help="whether to resume from a saved checkpoint")
+    parser.add_argument('--checkpoint', type=str, default=None,
+                        help="path to the checkpoint")
+    parser.add_argument('--output_dir', default='output', type=str, metavar='PATH',
+                        help='root of the output folder; the full path is <output>/<model_name>/<tag> (default: output)')
+    parser.add_argument("--learning_rate", default=1e-4, type=float,
+                        help="initial learning rate for SGD")
+    parser.add_argument("--epoch", default=200, type=int,
+                        help="number of training epochs")
+    args, unparsed = parser.parse_known_args()
+    return args
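A hypothetical invocation of the refactored entry point (all paths are illustrative placeholders, not values from the commit):

python train.py --dataset food101 \
    --image_path /data/food101/images/ \
    --train_path /data/food101/meta_data/train_full.txt \
    --test_path /data/food101/meta_data/test_full.txt \
    --weight_path /models/food2k_resnet50_0.0001.pth \
    --batchsize 8 --learning_rate 1e-4 --epoch 200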
def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch, net, optimizer, exp_lr_scheduler):
    exp_dir = store_name
@@ -140,10 +77,6 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
            # Step 1
            optimizer.zero_grad()
            #inputs1 = jigsaw_generator(inputs, 8)
-            inputs1 = None
-            if useJP:
-                _, _, _, _, output_1, _, _ = net(inputs1, False)
-            else:
-                _, _, _, _, output_1, _, _ = net(inputs, False)
+            _, _, _, _, output_1, _, _ = net(inputs, False)
            #print(output_1.shape)
            loss1 = CELoss(output_1, targets) * 1
@@ -153,10 +86,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
            # Step 2
            optimizer.zero_grad()
            #inputs2 = jigsaw_generator(inputs, 4)
-            inputs2 = None
-            if useJP:
-                _, _, _, _, _, output_2, _ = net(inputs2, False)
-            else:
-                _, _, _, _, _, output_2, _ = net(inputs, False)
+            _, _, _, _, _, output_2, _ = net(inputs, False)
            #print(output_2.shape)
            loss2 = CELoss(output_2, targets) * 1
@@ -166,10 +96,6 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
            # Step 3
            optimizer.zero_grad()
            #inputs3 = jigsaw_generator(inputs, 2)
-            inputs3 = None
-            if useJP:
-                _, _, _, _, _, _, output_3 = net(inputs3, False)
-            else:
-                _, _, _, _, _, _, output_3 = net(inputs, False)
+            _, _, _, _, _, _, output_3 = net(inputs, False)
            #print(output_3.shape)
            loss3 = CELoss(output_3, targets) * 1
@@ -178,7 +104,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
            optimizer.zero_grad()
-            x1, x2, x3, output_concat, _, _, _ = net(inputs, useAttn)
+            x1, x2, x3, output_concat, _, _, _ = net(inputs, True)
            concat_loss = CELoss(output_concat, targets) * 2
@@ -233,12 +159,25 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
                'Iteration %d, top1 = %.5f, top5 = %.5f, top1_combined = %.5f, top5_combined = %.5f, test_loss = %.6f\n' % (
                    epoch, val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss))
-net = load_model('resnet50_pmg', pretrain=False, require_grad=True, num_class=NUM_CATEGORIES)
-net.fc = nn.Linear(2048, 2000)
-state_dict = {}
-pretrained = torch.load(WEIGHT_PATH)
-for k, v in net.state_dict().items():
+def main():
+    args = parse_option()
+    train_dataset, train_loader, test_dataset, test_loader = \
+        load_data(image_path=args.image_path, train_dir=args.train_path, test_dir=args.test_path, batch_size=args.batchsize)
+    print('Data Preparation : Finished')
+    if args.dataset == "food101":
+        NUM_CATEGORIES = 101
+    elif args.dataset == "food500":
+        NUM_CATEGORIES = 500
+    elif args.dataset == "food2k":
+        NUM_CATEGORIES = 2000
+    # The head is first set to 2000 classes so the Food2k-pretrained weights
+    # load cleanly; it is replaced with a NUM_CATEGORIES-way head further down.
+    net = load_model('resnet50_pmg', pretrain=False, require_grad=True, num_class=NUM_CATEGORIES)
+    net.fc = nn.Linear(2048, 2000)
+    state_dict = {}
+    pretrained = torch.load(args.weight_path)
+    for k, v in net.state_dict().items():
+        # k[9:] drops a 9-character prefix (presumably "features.") so the key
+        # can be matched against the plain ResNet-50 checkpoint.
+        if k[9:] in pretrained.keys() and "fc" not in k:
+            state_dict[k] = pretrained[k[9:]]
+        elif "xx" in k and re.sub(r'xx[0-9]\.?', ".", k[9:]) in pretrained.keys():
@@ -247,19 +186,19 @@ for k, v in net.state_dict().items():
            state_dict[k] = v
            print(k)
-net.load_state_dict(state_dict)
-net.fc = nn.Linear(2048, NUM_CATEGORIES)
+    net.load_state_dict(state_dict)
+    net.fc = nn.Linear(2048, NUM_CATEGORIES)
-ignored_params = list(map(id, net.features.parameters()))
-new_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
-optimizer = optim.SGD([
-    {'params': net.features.parameters(), 'lr': LEARNING_RATE * 0.1},
-    {'params': new_params, 'lr': LEARNING_RATE}
-],
+    # The pretrained backbone gets a 10x smaller learning rate than the new layers.
+    ignored_params = list(map(id, net.features.parameters()))
+    new_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
+    optimizer = optim.SGD([
+        {'params': net.features.parameters(), 'lr': args.learning_rate * 0.1},
+        {'params': new_params, 'lr': args.learning_rate}
+    ],
        momentum=0.9, weight_decay=5e-4)
-exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
-for p in optimizer.param_groups:
+    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
+    # Log every param-group hyperparameter except the parameter tensors.
+    for p in optimizer.param_groups:
        outputs = ''
        for k, v in p.items():
            if k == 'params':
@@ -268,23 +207,25 @@ for p in optimizer.param_groups:
                outputs += (k + ': ' + str(v).ljust(10) + ' ')
        print(outputs)
-cudnn.benchmark = True
-net.cuda()
-net = nn.DataParallel(net)
+    cudnn.benchmark = True
+    net.cuda()
+    net = nn.DataParallel(net)

-if usecheckpoint:
+    if args.use_checkpoint:
        #net.load_state_dict(torch.load(checkpath))
-    net.module.load_state_dict(torch.load(checkpath).module.state_dict())
+        net.module.load_state_dict(torch.load(args.checkpoint).module.state_dict())
        print('load the checkpoint')

-train(nb_epoch=200,  # number of epochs
+    train(nb_epoch=args.epoch,  # number of epochs
        trainloader=train_loader,
        testloader=test_loader,
-      batch_size=BATCH_SIZE,  # batch size
-      store_name='food2k_448_from2k_only_cengnei',  # folder for output
+        batch_size=args.batchsize,  # batch size
+        store_name='model_448_from2k',  # folder for output
        start_epoch=0,  # the start epoch number when you resume training
        net=net,
        optimizer=optimizer,
        exp_lr_scheduler=exp_lr_scheduler)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
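The prefix-stripping in the weight-loading loop deserves a note: k[9:] removes the first nine characters of each parameter name, which matches the length of a "features." prefix, suggesting the PMG wrapper stores the backbone under a features attribute. A minimal standalone sketch of that remapping logic, with invented key names for illustration:

# remap_sketch.py -- illustrative only; key names are hypothetical
model_keys = ["features.conv1.weight", "features.layer1.0.conv1.weight", "fc.weight"]
pretrained = {"conv1.weight": 0, "layer1.0.conv1.weight": 1}  # stand-in values

state_dict = {}
for k in model_keys:
    stripped = k[9:]  # drop the 9-character "features." prefix
    if stripped in pretrained and "fc" not in k:
        state_dict[k] = pretrained[stripped]

print(sorted(state_dict))
# ['features.conv1.weight', 'features.layer1.0.conv1.weight']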