Commit 5c376a4c authored by lishen's avatar lishen

[fix]

parent 23f12127
...@@ -12,7 +12,8 @@ import torch.backends.cudnn as cudnn ...@@ -12,7 +12,8 @@ import torch.backends.cudnn as cudnn
import re import re
from utils import * from utils import *
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4" # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def parse_option(): def parse_option():
...@@ -35,7 +36,7 @@ def parse_option(): ...@@ -35,7 +36,7 @@ def parse_option():
help="The initial learning rate for SGD.") help="The initial learning rate for SGD.")
parser.add_argument("--epoch", default=200, type=int, parser.add_argument("--epoch", default=200, type=int,
help="The number of epochs.") help="The number of epochs.")
parser.add_argument("--test", action='store_true', default=True, parser.add_argument("--test", action='store_true', default=False,
help="Testing model.") help="Testing model.")
args, unparsed = parser.parse_known_args() args, unparsed = parser.parse_known_args()
return args return args
...@@ -173,6 +174,8 @@ def main(): ...@@ -173,6 +174,8 @@ def main():
NUM_CATEGORIES = 2000 NUM_CATEGORIES = 2000
elif args.dataset == "jkyy": elif args.dataset == "jkyy":
NUM_CATEGORIES = 1788 NUM_CATEGORIES = 1788
elif args.dataset == "test":
NUM_CATEGORIES = 5
net = load_model('resnet50', pretrain=False, require_grad=True, num_class=NUM_CATEGORIES) net = load_model('resnet50', pretrain=False, require_grad=True, num_class=NUM_CATEGORIES)
net.fc = nn.Linear(2048, 2000) net.fc = nn.Linear(2048, 2000)
...@@ -203,7 +206,7 @@ def main(): ...@@ -203,7 +206,7 @@ def main():
for p in optimizer.param_groups: for p in optimizer.param_groups:
outputs = '' outputs = ''
for k, v in p.items(): for k, v in p.items():
if k is 'params': if k == 'params':
outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ') outputs += (k + ': ' + str(v[0].shape).ljust(30) + ' ')
else: else:
outputs += (k + ': ' + str(v).ljust(10) + ' ') outputs += (k + ': ' + str(v).ljust(10) + ' ')
...@@ -211,7 +214,8 @@ def main(): ...@@ -211,7 +214,8 @@ def main():
cudnn.benchmark = True cudnn.benchmark = True
net.cuda() net.cuda()
device_ids = [0, 1, 2, 3, 4] # device_ids = [0, 1, 2, 3, 4]
device_ids = [0]
# net = nn.DataParallel(net).to(device_ids) # net = nn.DataParallel(net).to(device_ids)
net = nn.DataParallel(net, device_ids=device_ids) net = nn.DataParallel(net, device_ids=device_ids)
# optimizer = nn.DataParallel(optimizer, device_ids=device_ids) # optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
......
...@@ -41,6 +41,7 @@ class self_attention(nn.Module): ...@@ -41,6 +41,7 @@ class self_attention(nn.Module):
v = self.split_heads_2d(v, Nh) v = self.split_heads_2d(v, Nh)
dkh = dk // Nh dkh = dk // Nh
q = q.clone()
q *= dkh ** -0.5 q *= dkh ** -0.5
flat_q = torch.reshape(q, (N, Nh, dq // Nh, H * W)) flat_q = torch.reshape(q, (N, Nh, dq // Nh, H * W))
flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W)) flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment