Commit 2c86b7c2 authored by Liuyuxinict's avatar Liuyuxinict

update725

parent b905a8c5
......@@ -2,5 +2,6 @@
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
<mapping directory="$PROJECT_DIR$/prenet" vcs="Git" />
</component>
</project>
\ No newline at end of file
......@@ -3,7 +3,9 @@
<component name="ChangeListManager">
<list default="true" id="1de14600-5bec-46d2-972f-11687490a303" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/main.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/main.py" beforeDir="false" afterPath="$PROJECT_DIR$/main.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/model.py" beforeDir="false" afterPath="$PROJECT_DIR$/model.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......@@ -128,22 +130,22 @@
<screen x="0" y="0" width="1536" height="824" />
</state>
<state x="549" y="171" key="FileChooserDialogImpl/0.0.1536.824@0.0.1536.824" timestamp="1658372379078" />
<state width="1515" height="223" key="GridCell.Tab.0.bottom" timestamp="1658717893280">
<state width="1515" height="211" key="GridCell.Tab.0.bottom" timestamp="1658752201567">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state width="1515" height="223" key="GridCell.Tab.0.bottom/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="223" key="GridCell.Tab.0.center" timestamp="1658717893280">
<state width="1515" height="211" key="GridCell.Tab.0.bottom/0.0.1536.824@0.0.1536.824" timestamp="1658752201567" />
<state width="1515" height="211" key="GridCell.Tab.0.center" timestamp="1658752201566">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state width="1515" height="223" key="GridCell.Tab.0.center/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="223" key="GridCell.Tab.0.left" timestamp="1658717893279">
<state width="1515" height="211" key="GridCell.Tab.0.center/0.0.1536.824@0.0.1536.824" timestamp="1658752201566" />
<state width="1515" height="211" key="GridCell.Tab.0.left" timestamp="1658752201566">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state width="1515" height="223" key="GridCell.Tab.0.left/0.0.1536.824@0.0.1536.824" timestamp="1658717893279" />
<state width="1515" height="223" key="GridCell.Tab.0.right" timestamp="1658717893280">
<state width="1515" height="211" key="GridCell.Tab.0.left/0.0.1536.824@0.0.1536.824" timestamp="1658752201566" />
<state width="1515" height="211" key="GridCell.Tab.0.right" timestamp="1658752201567">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state width="1515" height="223" key="GridCell.Tab.0.right/0.0.1536.824@0.0.1536.824" timestamp="1658717893280" />
<state width="1515" height="211" key="GridCell.Tab.0.right/0.0.1536.824@0.0.1536.824" timestamp="1658752201567" />
<state width="1515" height="290" key="GridCell.Tab.1.bottom" timestamp="1658494944285">
<screen x="0" y="0" width="1536" height="824" />
</state>
......
......@@ -24,9 +24,9 @@ def parse_option():
parser.add_argument("--test_path", type=str, default="E:/datasets/food101/meta_data/test_full.txt",
help='path to testing list')
parser.add_argument('--weight_path', default="E:/Pretrained_model/food2k_resnet50_0.0001.pth", help='path to the pretrained model')
parser.add_argument('--use_checkpoint', action='store_true', default=False,
parser.add_argument('--use_checkpoint', action='store_true', default=True,
help="whether to use gradient checkpointing to save memory")
parser.add_argument('--checkpoint', type=str, default=None,
parser.add_argument('--checkpoint', type=str, default="E:/Pretrained_model/model.pth",
help="the path to checkpoint")
parser.add_argument('--output_dir', default='output', type=str, metavar='PATH',
help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
......@@ -174,7 +174,7 @@ def main():
NUM_CATEGORIES = 2000
net = load_model('resnet50_pmg',pretrain=False,require_grad=True,num_class=NUM_CATEGORIES)
net = load_model('resnet50',pretrain=False,require_grad=True,num_class=NUM_CATEGORIES)
net.fc = nn.Linear(2048, 2000)
state_dict = {}
pretrained = torch.load(args.weight_path)
......@@ -215,6 +215,8 @@ def main():
if args.use_checkpoint:
#net.load_state_dict(torch.load(checkpath))
model = torch.load(args.checkpoint).module.state_dict()
net.module.load_state_dict(torch.load(args.checkpoint).module.state_dict())
print('load the checkpoint')
......
......@@ -6,9 +6,9 @@ from layer_self_attention import layer_self_attention
from dropblock import DropBlock2D
import numpy as np
class PMG(nn.Module):
class PRENet(nn.Module):
def __init__(self, model, feature_size, classes_num):
super(PMG, self).__init__()
super(PRENet, self).__init__()
self.features = model
......
prenet @ 5aaf02d3
Subproject commit 5aaf02d3935c8777f42f158c2edc68bdbbd89880
......@@ -11,11 +11,11 @@ from Resnet import *
def load_model(model_name, pretrain=True, require_grad=True, num_class=1000, pretrained_model=None):
print('==> Building model..')
if model_name == 'resnet50_pmg':
if model_name == 'resnet50':
net = resnet50(pretrained=pretrain, path=pretrained_model)
#for param in net.parameters():
#param.requires_grad = require_grad
net = PMG(net, 512, num_class)
net = PRENet(net, 512, num_class)
return net
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment