Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Sign in
Toggle navigation
P
prenet
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
lishen
prenet
Commits
7d06ba8c
Commit
7d06ba8c
authored
Jul 25, 2022
by
Liuyuxinict
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
tijiao
parent
3b197eb1
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
198 additions
and
169 deletions
+198
-169
vcs.xml
.idea/vcs.xml
+6
-0
workspace.xml
.idea/workspace.xml
+19
-9
data_loader.cpython-36.pyc
__pycache__/data_loader.cpython-36.pyc
+0
-0
data_loader.py
data_loader.py
+72
-0
train.py
train.py
+101
-160
No files found.
.idea/vcs.xml
0 → 100644
View file @
7d06ba8c
<?xml version="1.0" encoding="UTF-8"?>
<project
version=
"4"
>
<component
name=
"VcsDirectoryMappings"
>
<mapping
directory=
"$PROJECT_DIR$"
vcs=
"Git"
/>
</component>
</project>
\ No newline at end of file
.idea/workspace.xml
View file @
7d06ba8c
<?xml version="1.0" encoding="UTF-8"?>
<project
version=
"4"
>
<component
name=
"ChangeListManager"
>
<list
default=
"true"
id=
"1de14600-5bec-46d2-972f-11687490a303"
name=
"Default Changelist"
comment=
""
/>
<list
default=
"true"
id=
"1de14600-5bec-46d2-972f-11687490a303"
name=
"Default Changelist"
comment=
""
>
<change
afterPath=
"$PROJECT_DIR$/.idea/vcs.xml"
afterDir=
"false"
/>
<change
afterPath=
"$PROJECT_DIR$/data_loader.py"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/.idea/workspace.xml"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/.idea/workspace.xml"
afterDir=
"false"
/>
<change
beforePath=
"$PROJECT_DIR$/train.py"
beforeDir=
"false"
afterPath=
"$PROJECT_DIR$/train.py"
afterDir=
"false"
/>
</list>
<option
name=
"SHOW_DIALOG"
value=
"false"
/>
<option
name=
"HIGHLIGHT_CONFLICTS"
value=
"true"
/>
<option
name=
"HIGHLIGHT_NON_ACTIVE_CHANGELIST"
value=
"false"
/>
...
...
@@ -14,7 +19,11 @@
</list>
</option>
</component>
<component
name=
"Git.Settings"
>
<option
name=
"RECENT_GIT_ROOT_PATH"
value=
"$PROJECT_DIR$"
/>
</component>
<component
name=
"ProjectId"
id=
"1hp1mgftlnNubdit1AM27vnxPWs"
/>
<component
name=
"ProjectLevelVcsManager"
settingsEditedManually=
"true"
/>
<component
name=
"ProjectViewState"
>
<option
name=
"hideEmptyMiddlePackages"
value=
"true"
/>
<option
name=
"showExcludedFiles"
value=
"true"
/>
...
...
@@ -22,6 +31,7 @@
</component>
<component
name=
"PropertiesComponent"
>
<property
name=
"RunOnceActivity.ShowReadmeOnStart"
value=
"true"
/>
<property
name=
"SHARE_PROJECT_CONFIGURATION_FILES"
value=
"true"
/>
<property
name=
"last_opened_file_path"
value=
"$PROJECT_DIR$"
/>
<property
name=
"settings.editor.selected.configurable"
value=
"com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable"
/>
</component>
...
...
@@ -120,22 +130,22 @@
<screen
x=
"0"
y=
"0"
width=
"1536"
height=
"824"
/>
</state>
<state
x=
"549"
y=
"171"
key=
"FileChooserDialogImpl/0.0.1536.824@0.0.1536.824"
timestamp=
"1658372379078"
/>
<state
width=
"1515"
height=
"2
90"
key=
"GridCell.Tab.0.bottom"
timestamp=
"1658494944285
"
>
<state
width=
"1515"
height=
"2
23"
key=
"GridCell.Tab.0.bottom"
timestamp=
"1658717893280
"
>
<screen
x=
"0"
y=
"0"
width=
"1536"
height=
"824"
/>
</state>
<state
width=
"1515"
height=
"2
90"
key=
"GridCell.Tab.0.bottom/0.0.1536.824@0.0.1536.824"
timestamp=
"1658494944285
"
/>
<state
width=
"1515"
height=
"2
90"
key=
"GridCell.Tab.0.center"
timestamp=
"1658494944285
"
>
<state
width=
"1515"
height=
"2
23"
key=
"GridCell.Tab.0.bottom/0.0.1536.824@0.0.1536.824"
timestamp=
"1658717893280
"
/>
<state
width=
"1515"
height=
"2
23"
key=
"GridCell.Tab.0.center"
timestamp=
"1658717893280
"
>
<screen
x=
"0"
y=
"0"
width=
"1536"
height=
"824"
/>
</state>
<state
width=
"1515"
height=
"2
90"
key=
"GridCell.Tab.0.center/0.0.1536.824@0.0.1536.824"
timestamp=
"1658494944285
"
/>
<state
width=
"1515"
height=
"2
90"
key=
"GridCell.Tab.0.left"
timestamp=
"1658494944285
"
>
<state
width=
"1515"
height=
"2
23"
key=
"GridCell.Tab.0.center/0.0.1536.824@0.0.1536.824"
timestamp=
"1658717893280
"
/>
<state
width=
"1515"
height=
"2
23"
key=
"GridCell.Tab.0.left"
timestamp=
"1658717893279
"
>
<screen
x=
"0"
y=
"0"
width=
"1536"
height=
"824"
/>
</state>
<state
width=
"1515"
height=
"2
90"
key=
"GridCell.Tab.0.left/0.0.1536.824@0.0.1536.824"
timestamp=
"1658494944285
"
/>
<state
width=
"1515"
height=
"2
90"
key=
"GridCell.Tab.0.right"
timestamp=
"1658494944285
"
>
<state
width=
"1515"
height=
"2
23"
key=
"GridCell.Tab.0.left/0.0.1536.824@0.0.1536.824"
timestamp=
"1658717893279
"
/>
<state
width=
"1515"
height=
"2
23"
key=
"GridCell.Tab.0.right"
timestamp=
"1658717893280
"
>
<screen
x=
"0"
y=
"0"
width=
"1536"
height=
"824"
/>
</state>
<state
width=
"1515"
height=
"2
90"
key=
"GridCell.Tab.0.right/0.0.1536.824@0.0.1536.824"
timestamp=
"1658494944285
"
/>
<state
width=
"1515"
height=
"2
23"
key=
"GridCell.Tab.0.right/0.0.1536.824@0.0.1536.824"
timestamp=
"1658717893280
"
/>
<state
width=
"1515"
height=
"290"
key=
"GridCell.Tab.1.bottom"
timestamp=
"1658494944285"
>
<screen
x=
"0"
y=
"0"
width=
"1536"
height=
"824"
/>
</state>
...
...
__pycache__/data_loader.cpython-36.pyc
0 → 100644
View file @
7d06ba8c
File added
data_loader.py
0 → 100644
View file @
7d06ba8c
import
torch
import
PIL
from
PIL
import
Image
import
torch.utils.data
as
data
from
torchvision
import
datasets
,
transforms
def
My_loader
(
path
):
return
PIL
.
Image
.
open
(
path
)
.
convert
(
'RGB'
)
class
MyDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
txt_dir
,
image_path
,
transform
=
None
,
target_transform
=
None
,
loader
=
My_loader
):
data_txt
=
open
(
txt_dir
,
'r'
)
imgs
=
[]
for
line
in
data_txt
:
line
=
line
.
strip
()
words
=
line
.
split
(
' '
)
imgs
.
append
((
words
[
0
],
int
(
words
[
1
]
.
strip
())))
self
.
imgs
=
imgs
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
loader
=
My_loader
self
.
image_path
=
image_path
def
__len__
(
self
):
return
len
(
self
.
imgs
)
def
__getitem__
(
self
,
index
):
img_name
,
label
=
self
.
imgs
[
index
]
# label = list(map(int, label))
# print label
# print type(label)
#img = self.loader('/home/vipl/llh/food101_finetuning/food101_vgg/origal_data/images/'+img_name.replace("\\","/"))
img
=
self
.
loader
(
self
.
image_path
+
img_name
)
# print img
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
# print img.size()
# label =torch.Tensor(label)
# print label.size()
return
img
,
label
# if the label is the single-label it can be the int
# if the multilabel can be the list to torch.tensor
def
load_data
(
image_path
,
train_dir
,
test_dir
,
batch_size
):
normalize
=
transforms
.
Normalize
(
mean
=
[
0.5457954
,
0.44430383
,
0.34424934
],
std
=
[
0.23273608
,
0.24383051
,
0.24237761
])
train_transforms
=
transforms
.
Compose
([
transforms
.
RandomHorizontalFlip
(
p
=
0.5
),
# default value is 0.5
transforms
.
RandomRotation
(
degrees
=
15
),
transforms
.
ColorJitter
(
brightness
=
0.126
,
saturation
=
0.5
),
transforms
.
Resize
((
550
,
550
)),
transforms
.
RandomCrop
(
448
),
transforms
.
ToTensor
(),
normalize
])
# transforms of test dataset
test_transforms
=
transforms
.
Compose
([
transforms
.
Resize
((
550
,
550
)),
transforms
.
CenterCrop
((
448
,
448
)),
transforms
.
ToTensor
(),
normalize
])
train_dataset
=
MyDataset
(
txt_dir
=
train_dir
,
image_path
=
image_path
,
transform
=
train_transforms
)
test_dataset
=
MyDataset
(
txt_dir
=
test_dir
,
image_path
=
image_path
,
transform
=
test_transforms
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
train_dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
0
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
test_dataset
,
batch_size
=
batch_size
//
2
,
shuffle
=
False
,
num_workers
=
0
)
return
train_dataset
,
train_loader
,
test_dataset
,
test_loader
train.py
View file @
7d06ba8c
from
__future__
import
print_function
import
os
from
PIL
import
Image
import
torch.utils.data
as
data
import
os
import
PIL
import
argparse
from
tqdm
import
tqdm
import
torch.optim
as
optim
from
data_loader
import
load_data
from
torch.optim
import
lr_scheduler
import
torch.backends.cudnn
as
cudnn
import
re
from
utils
import
*
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"0"
EPOCH
=
200
# number of times for each run-through
BATCH_SIZE
=
2
# number of images for each epoch
LEARNING_RATE
=
0.0001
# default learning rate
GPU_IN_USE
=
torch
.
cuda
.
is_available
()
# whether using GPU
DIR_TRAIN_IMAGES
=
r'E:\datasets\food101\meta_data\train_full.txt'
#DIR_TRAIN_IMAGES = "/home/vipl/lyx/train_full.txt"
DIR_TEST_IMAGES
=
r'E:\datasets\food101\meta_data\test_full.txt'
#DIR_TEST_IMAGES = "/home/vipl/lyx/test_full.txt"
Image_path
=
r"E:/datasets/food101/images/"
#Image_path = "/home/vipl/lizhuo/dataset_food/food101/images/"
#NUM_CATEGORIES = 500
NUM_CATEGORIES
=
101
#WEIGHT_PATH= '/home/vipl/lyx/resnet50.pth'
WEIGHT_PATH
=
r'E:/Pretrained_model/food2k_resnet50_0.0001.pth'
checkpoint
=
''
useJP
=
False
#use Jigsaw Patches during PMG food2k_448_from2k_only_cengnei
usecheckpoint
=
False
checkpath
=
"./food2k_448_from2k_only_cengnei/model.pth"
useAttn
=
True
normalize
=
transforms
.
Normalize
(
mean
=
[
0.5457954
,
0.44430383
,
0.34424934
],
std
=
[
0.23273608
,
0.24383051
,
0.24237761
])
train_transforms
=
transforms
.
Compose
([
transforms
.
RandomHorizontalFlip
(
p
=
0.5
),
# default value is 0.5
transforms
.
RandomRotation
(
degrees
=
15
),
transforms
.
ColorJitter
(
brightness
=
0.126
,
saturation
=
0.5
),
transforms
.
Resize
((
550
,
550
)),
transforms
.
RandomCrop
(
448
),
transforms
.
ToTensor
(),
normalize
])
# transforms of test dataset
test_transforms
=
transforms
.
Compose
([
transforms
.
Resize
((
550
,
550
)),
transforms
.
CenterCrop
((
448
,
448
)),
transforms
.
ToTensor
(),
normalize
])
def
My_loader
(
path
):
return
PIL
.
Image
.
open
(
path
)
.
convert
(
'RGB'
)
class
MyDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
txt_dir
,
transform
=
None
,
target_transform
=
None
,
loader
=
My_loader
):
data_txt
=
open
(
txt_dir
,
'r'
)
imgs
=
[]
for
line
in
data_txt
:
line
=
line
.
strip
()
words
=
line
.
split
(
' '
)
imgs
.
append
((
words
[
0
],
int
(
words
[
1
]
.
strip
())))
self
.
imgs
=
imgs
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
loader
=
My_loader
def
__len__
(
self
):
return
len
(
self
.
imgs
)
def
__getitem__
(
self
,
index
):
img_name
,
label
=
self
.
imgs
[
index
]
# label = list(map(int, label))
# print label
# print type(label)
#img = self.loader('/home/vipl/llh/food101_finetuning/food101_vgg/origal_data/images/'+img_name.replace("\\","/"))
img
=
self
.
loader
(
Image_path
+
img_name
)
# print img
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
# print img.size()
# label =torch.Tensor(label)
# print label.size()
return
img
,
label
# if the label is the single-label it can be the int
# if the multilabel can be the list to torch.tensor
train_dataset
=
MyDataset
(
txt_dir
=
DIR_TRAIN_IMAGES
,
transform
=
train_transforms
)
test_dataset
=
MyDataset
(
txt_dir
=
DIR_TEST_IMAGES
,
transform
=
test_transforms
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
train_dataset
,
batch_size
=
BATCH_SIZE
,
shuffle
=
True
,
num_workers
=
0
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
test_dataset
,
batch_size
=
BATCH_SIZE
//
2
,
shuffle
=
False
,
num_workers
=
0
)
print
(
'Data Preparation : Finished'
)
def
parse_option
():
parser
=
argparse
.
ArgumentParser
(
'Progressive Region Enhancement Network(PRENet) for training and testing'
)
parser
.
add_argument
(
'--batchsize'
,
default
=
2
,
type
=
int
,
help
=
"batch size for single GPU"
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
default
=
'food101'
,
help
=
'food2k, food101, food500'
)
parser
.
add_argument
(
'--image_path'
,
type
=
str
,
default
=
"E:/datasets/food101/images/"
,
help
=
'path to dataset'
)
parser
.
add_argument
(
"--train_path"
,
type
=
str
,
default
=
"E:/datasets/food101/meta_data/train_full.txt"
,
help
=
'path to training list'
)
parser
.
add_argument
(
"--test_path"
,
type
=
str
,
default
=
"E:/datasets/food101/meta_data/test_full.txt"
,
help
=
'path to testing list'
)
parser
.
add_argument
(
'--weight_path'
,
default
=
"E:/Pretrained_model/food2k_resnet50_0.0001.pth"
,
help
=
'path to the pretrained model'
)
parser
.
add_argument
(
'--use_checkpoint'
,
action
=
'store_true'
,
default
=
False
,
help
=
"whether to use gradient checkpointing to save memory"
)
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
None
,
help
=
"the path to checkpoint"
)
parser
.
add_argument
(
'--output_dir'
,
default
=
'output'
,
type
=
str
,
metavar
=
'PATH'
,
help
=
'root of output folder, the full path is <output>/<model_name>/<tag> (default: output)'
)
parser
.
add_argument
(
"--learning_rate"
,
default
=
1e-4
,
type
=
float
,
help
=
"The initial learning rate for SGD."
)
parser
.
add_argument
(
"--epoch"
,
default
=
200
,
type
=
int
,
help
=
"The number of epochs."
)
args
,
unparsed
=
parser
.
parse_known_args
()
return
args
def
train
(
nb_epoch
,
trainloader
,
testloader
,
batch_size
,
store_name
,
start_epoch
,
net
,
optimizer
,
exp_lr_scheduler
):
exp_dir
=
store_name
...
...
@@ -140,10 +77,6 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
# Step 1
optimizer
.
zero_grad
()
#inputs1 = jigsaw_generator(inputs, 8)
inputs1
=
None
if
useJP
:
_
,
_
,
_
,
_
,
output_1
,
_
,
_
=
net
(
inputs1
,
False
)
else
:
_
,
_
,
_
,
_
,
output_1
,
_
,
_
=
net
(
inputs
,
False
)
#print(output_1.shape)
loss1
=
CELoss
(
output_1
,
targets
)
*
1
...
...
@@ -153,10 +86,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
# Step 2
optimizer
.
zero_grad
()
#inputs2 = jigsaw_generator(inputs, 4)
inputs2
=
None
if
useJP
:
_
,
_
,
_
,
_
,
_
,
output_2
,
_
,
=
net
(
inputs2
,
False
)
else
:
_
,
_
,
_
,
_
,
_
,
output_2
,
_
,
=
net
(
inputs
,
False
)
#print(output_2.shape)
loss2
=
CELoss
(
output_2
,
targets
)
*
1
...
...
@@ -166,10 +96,6 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
# Step 3
optimizer
.
zero_grad
()
#inputs3 = jigsaw_generator(inputs, 2)
inputs3
=
None
if
useJP
:
_
,
_
,
_
,
_
,
_
,
_
,
output_3
=
net
(
inputs3
,
False
)
else
:
_
,
_
,
_
,
_
,
_
,
_
,
output_3
=
net
(
inputs
,
False
)
#print(output_3.shape)
loss3
=
CELoss
(
output_3
,
targets
)
*
1
...
...
@@ -178,7 +104,7 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
optimizer
.
zero_grad
()
x1
,
x2
,
x3
,
output_concat
,
_
,
_
,
_
=
net
(
inputs
,
useAttn
)
x1
,
x2
,
x3
,
output_concat
,
_
,
_
,
_
=
net
(
inputs
,
True
)
concat_loss
=
CELoss
(
output_concat
,
targets
)
*
2
...
...
@@ -233,12 +159,25 @@ def train(nb_epoch, trainloader, testloader, batch_size, store_name, start_epoch
'Iteration
%
d, top1 =
%.5
f, top5 =
%.5
f, top1_combined =
%.5
f, top5_combined =
%.5
f, test_loss =
%.6
f
\n
'
%
(
epoch
,
val_acc
,
val5_acc
,
val_acc_com
,
val5_acc_com
,
val_loss
))
net
=
load_model
(
'resnet50_pmg'
,
pretrain
=
False
,
require_grad
=
True
,
num_class
=
NUM_CATEGORIES
)
net
.
fc
=
nn
.
Linear
(
2048
,
2000
)
state_dict
=
{}
pretrained
=
torch
.
load
(
WEIGHT_PATH
)
for
k
,
v
in
net
.
state_dict
()
.
items
():
def
main
():
args
=
parse_option
()
train_dataset
,
train_loader
,
test_dataset
,
test_loader
=
\
load_data
(
image_path
=
args
.
image_path
,
train_dir
=
args
.
train_path
,
test_dir
=
args
.
test_path
,
batch_size
=
args
.
batchsize
)
print
(
'Data Preparation : Finished'
)
if
args
.
dataset
==
"food101"
:
NUM_CATEGORIES
=
101
elif
args
.
dataset
==
"food500"
:
NUM_CATEGORIES
=
500
elif
args
.
dataset
==
"food2k"
:
NUM_CATEGORIES
=
2000
net
=
load_model
(
'resnet50_pmg'
,
pretrain
=
False
,
require_grad
=
True
,
num_class
=
NUM_CATEGORIES
)
net
.
fc
=
nn
.
Linear
(
2048
,
2000
)
state_dict
=
{}
pretrained
=
torch
.
load
(
args
.
weight_path
)
for
k
,
v
in
net
.
state_dict
()
.
items
():
if
k
[
9
:]
in
pretrained
.
keys
()
and
"fc"
not
in
k
:
state_dict
[
k
]
=
pretrained
[
k
[
9
:]]
elif
"xx"
in
k
and
re
.
sub
(
r'xx[0-9]\.?'
,
"."
,
k
[
9
:])
in
pretrained
.
keys
():
...
...
@@ -247,19 +186,19 @@ for k, v in net.state_dict().items():
state_dict
[
k
]
=
v
print
(
k
)
net
.
load_state_dict
(
state_dict
)
net
.
fc
=
nn
.
Linear
(
2048
,
NUM_CATEGORIES
)
net
.
load_state_dict
(
state_dict
)
net
.
fc
=
nn
.
Linear
(
2048
,
NUM_CATEGORIES
)
ignored_params
=
list
(
map
(
id
,
net
.
features
.
parameters
()))
new_params
=
filter
(
lambda
p
:
id
(
p
)
not
in
ignored_params
,
net
.
parameters
())
optimizer
=
optim
.
SGD
([
{
'params'
:
net
.
features
.
parameters
(),
'lr'
:
LEARNING_RATE
*
0.1
},
{
'params'
:
new_params
,
'lr'
:
LEARNING_RATE
}
],
ignored_params
=
list
(
map
(
id
,
net
.
features
.
parameters
()))
new_params
=
filter
(
lambda
p
:
id
(
p
)
not
in
ignored_params
,
net
.
parameters
())
optimizer
=
optim
.
SGD
([
{
'params'
:
net
.
features
.
parameters
(),
'lr'
:
args
.
learning_rate
*
0.1
},
{
'params'
:
new_params
,
'lr'
:
args
.
learning_rate
}
],
momentum
=
0.9
,
weight_decay
=
5e-4
)
exp_lr_scheduler
=
lr_scheduler
.
StepLR
(
optimizer
,
step_size
=
2
,
gamma
=
0.9
)
for
p
in
optimizer
.
param_groups
:
exp_lr_scheduler
=
lr_scheduler
.
StepLR
(
optimizer
,
step_size
=
2
,
gamma
=
0.9
)
for
p
in
optimizer
.
param_groups
:
outputs
=
''
for
k
,
v
in
p
.
items
():
if
k
is
'params'
:
...
...
@@ -268,23 +207,25 @@ for p in optimizer.param_groups:
outputs
+=
(
k
+
': '
+
str
(
v
)
.
ljust
(
10
)
+
' '
)
print
(
outputs
)
cudnn
.
benchmark
=
True
net
.
cuda
()
net
=
nn
.
DataParallel
(
net
)
cudnn
.
benchmark
=
True
net
.
cuda
()
net
=
nn
.
DataParallel
(
net
)
if
use
checkpoint
:
if
args
.
use_
checkpoint
:
#net.load_state_dict(torch.load(checkpath))
net
.
module
.
load_state_dict
(
torch
.
load
(
checkpath
)
.
module
.
state_dict
())
net
.
module
.
load_state_dict
(
torch
.
load
(
args
.
checkpoint
)
.
module
.
state_dict
())
print
(
'load the checkpoint'
)
train
(
nb_epoch
=
200
,
# number of epoch
train
(
nb_epoch
=
args
.
epoch
,
# number of epoch
trainloader
=
train_loader
,
testloader
=
test_loader
,
batch_size
=
BATCH_SIZE
,
# batch size
store_name
=
'food2k_448_from2k_only_cengnei
'
,
# folder for output
batch_size
=
args
.
batchsize
,
# batch size
store_name
=
'model_448_from2k
'
,
# folder for output
start_epoch
=
0
,
net
=
net
,
optimizer
=
optimizer
,
exp_lr_scheduler
=
exp_lr_scheduler
)
# the start epoch number when you resume the training
if
__name__
==
"__main__"
:
main
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment