Commit 5c376a4c authored May 29, 2023 by lishen

    [fix]
parent 23f12127

Showing 2 changed files with 9 additions and 4 deletions (+9 -4)

    main.py            +8 -4
    self_attention.py  +1 -0
main.py
@@ -12,7 +12,8 @@ import torch.backends.cudnn as cudnn
 import re
 from utils import *
-os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 def parse_option():
@@ -35,7 +36,7 @@ def parse_option():
                         help="The initial learning rate for SGD.")
     parser.add_argument("--epoch", default=200, type=int,
                         help="The number of epochs.")
-    parser.add_argument("--test", action='store_true', default=True,
+    parser.add_argument("--test", action='store_true', default=False,
                         help="Testing model.")
     args, unparsed = parser.parse_known_args()
     return args
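Why the --test default moves from True to False: with action='store_true' the flag can only switch the value on, so a default of True could never be overridden from the command line, which is presumably what this change addresses. A minimal, self-contained sketch (not part of the repository) of that behaviour:

    import argparse

    # With action='store_true', passing --test stores True; the default is what
    # you get when the flag is absent, so default=True would make the flag a no-op.
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", action='store_true', default=False,
                        help="Testing model.")
    print(parser.parse_args([]).test)           # False
    print(parser.parse_args(["--test"]).test)   # True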
@@ -173,6 +174,8 @@ def main():
         NUM_CATEGORIES = 2000
     elif args.dataset == "jkyy":
         NUM_CATEGORIES = 1788
+    elif args.dataset == "test":
+        NUM_CATEGORIES = 5
     net = load_model('resnet50', pretrain=False, require_grad=True, num_class=NUM_CATEGORIES)
     net.fc = nn.Linear(2048, 2000)
@@ -203,7 +206,7 @@ def main():
     for p in optimizer.param_groups:
         outputs = ''
         for k, v in p.items():
-            if k is 'params':
+            if k == 'params':
                 outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ')
             else:
                 outputs += (k + ': ' + str(v).ljust(10) + ' ')
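The is → == change above fixes a genuine correctness issue: is compares object identity while == compares values, so matching a string key with is only works when the interpreter happens to intern both strings. A standalone sketch (not from the repository) of the difference:

    key = "".join(["par", "ams"])    # value-equal to 'params', but a distinct object
    print(key == "params")           # True  -- value comparison, always correct
    print(key is "params")           # False -- identity comparison; CPython 3.8+ even
                                     # warns that "is" with a literal is almost always a bug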
@@ -211,7 +214,8 @@ def main():
     cudnn.benchmark = True
     net.cuda()
-    device_ids = [0, 1, 2, 3, 4]
+    # device_ids = [0, 1, 2, 3, 4]
+    device_ids = [0]
     # net = nn.DataParallel(net).to(device_ids)
     net = nn.DataParallel(net, device_ids=device_ids)
     # optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
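Both device-related edits in this file point the same way: only GPU 0 is exposed, and nn.DataParallel is handed a matching device_ids list, since device_ids index into whatever CUDA_VISIBLE_DEVICES leaves visible. A minimal sketch of how the two settings interact (stand-in module, not the project's network):

    import os

    # Must be set before CUDA is initialised; afterwards only one device is
    # visible, so device index 0 is the only valid entry for device_ids.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    import torch
    import torch.nn as nn

    net = nn.Linear(2048, 2000)                              # stand-in for the real model
    if torch.cuda.is_available():
        net = net.cuda()
        device_ids = list(range(torch.cuda.device_count()))  # [0] with the setting above
        net = nn.DataParallel(net, device_ids=device_ids)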
self_attention.py
@@ -41,6 +41,7 @@ class self_attention(nn.Module):
         v = self.split_heads_2d(v, Nh)
         dkh = dk // Nh
+        q = q.clone()
         q *= dkh ** -0.5
         flat_q = torch.reshape(q, (N, Nh, dq // Nh, H * W))
         flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
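The single added line copies q before the in-place scaling. The commit does not state the motivation, but one common reason for such a change is that PyTorch rejects in-place operations on tensors the autograd engine still needs in their original form (for example a leaf that requires grad); cloning first gives an independent, differentiable buffer that can be scaled in place. A minimal sketch with hypothetical shapes, roughly (N, Nh, dkh, H*W) as in the code above:

    import torch

    dkh = 64
    q = torch.randn(2, 8, dkh, 7 * 7, requires_grad=True)

    # q *= dkh ** -0.5        # RuntimeError: in-place operation on a leaf that requires grad
    q_scaled = q.clone()      # differentiable copy in a fresh buffer
    q_scaled *= dkh ** -0.5   # in-place scaling on the copy is fine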