Commit b905a8c5 authored by Liuyuxinict's avatar Liuyuxinict

update725

parent 8ff46ec6
...@@ -2,10 +2,8 @@ ...@@ -2,10 +2,8 @@
<project version="4"> <project version="4">
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="1de14600-5bec-46d2-972f-11687490a303" name="Default Changelist" comment=""> <list default="true" id="1de14600-5bec-46d2-972f-11687490a303" name="Default Changelist" comment="">
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
<change afterPath="$PROJECT_DIR$/data_loader.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/main.py" afterDir="false" />
</list> </list>
<option name="SHOW_DIALOG" value="false" /> <option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" /> <option name="HIGHLIGHT_CONFLICTS" value="true" />
...@@ -40,7 +38,7 @@ ...@@ -40,7 +38,7 @@
<recent name="D:\PMG" /> <recent name="D:\PMG" />
</key> </key>
</component> </component>
<component name="RunManager" selected="Python.train"> <component name="RunManager" selected="Python.main">
<configuration name="1" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> <configuration name="1" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="PMG-Progressive-Multi-Granularity-Training-master" /> <module name="PMG-Progressive-Multi-Granularity-Training-master" />
<option name="INTERPRETER_OPTIONS" value="" /> <option name="INTERPRETER_OPTIONS" value="" />
...@@ -62,7 +60,7 @@ ...@@ -62,7 +60,7 @@
<option name="INPUT_FILE" value="" /> <option name="INPUT_FILE" value="" />
<method v="2" /> <method v="2" />
</configuration> </configuration>
<configuration name="train" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> <configuration name="main" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="PMG-Progressive-Multi-Granularity-Training-master" /> <module name="PMG-Progressive-Multi-Granularity-Training-master" />
<option name="INTERPRETER_OPTIONS" value="" /> <option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" /> <option name="PARENT_ENVS" value="true" />
...@@ -74,7 +72,7 @@ ...@@ -74,7 +72,7 @@
<option name="IS_MODULE_SDK" value="true" /> <option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" /> <option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" /> <option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/train.py" /> <option name="SCRIPT_NAME" value="C:\Users\刘宇昕\prenet\main.py" />
<option name="PARAMETERS" value="" /> <option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" /> <option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" /> <option name="EMULATE_TERMINAL" value="false" />
...@@ -106,7 +104,7 @@ ...@@ -106,7 +104,7 @@
</configuration> </configuration>
<recent_temporary> <recent_temporary>
<list> <list>
<item itemvalue="Python.train" /> <item itemvalue="Python.main" />
<item itemvalue="Python.visualization" /> <item itemvalue="Python.visualization" />
<item itemvalue="Python.1" /> <item itemvalue="Python.1" />
</list> </list>
......
...@@ -34,6 +34,8 @@ def parse_option(): ...@@ -34,6 +34,8 @@ def parse_option():
help="The initial learning rate for SGD.") help="The initial learning rate for SGD.")
parser.add_argument("--epoch", default=200, type=int, parser.add_argument("--epoch", default=200, type=int,
help="The number of epochs.") help="The number of epochs.")
parser.add_argument("--test", action='store_true', default=True,
help="Testing model.")
args, unparsed = parser.parse_known_args() args, unparsed = parser.parse_known_args()
return args return args
...@@ -150,7 +152,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch ...@@ -150,7 +152,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
epoch, train_acc, train_loss, train_loss1 / (idx + 1), train_loss2 / (idx + 1), train_loss3 / (idx + 1), epoch, train_acc, train_loss, train_loss1 / (idx + 1), train_loss2 / (idx + 1), train_loss3 / (idx + 1),
train_loss4 / (idx + 1))) train_loss4 / (idx + 1)))
val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, CELoss, batch_size, testloader,useAttn) val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, CELoss, batch_size, testloader,True)
if val_acc > max_val_acc: if val_acc > max_val_acc:
max_val_acc = val_acc max_val_acc = val_acc
torch.save(net, './' + store_name + '/model.pth') torch.save(net, './' + store_name + '/model.pth')
...@@ -216,6 +218,12 @@ def main(): ...@@ -216,6 +218,12 @@ def main():
net.module.load_state_dict(torch.load(args.checkpoint).module.state_dict()) net.module.load_state_dict(torch.load(args.checkpoint).module.state_dict())
print('load the checkpoint') print('load the checkpoint')
if args.test:
val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss = test(net, nn.CrossEntropyLoss(), args.batchsize, test_loader, True)
print('Accuracy of the network on the val images: top1 = %.5f, top5 = %.5f, top1_combined = %.5f, top5_combined = %.5f, test_loss = %.6f\n' % (
val_acc, val5_acc, val_acc_com, val5_acc_com, val_loss))
return
train(nb_epoch=args.epoch, # number of epoch train(nb_epoch=args.epoch, # number of epoch
trainloader=train_loader, trainloader=train_loader,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment