CV+Deep Learning - Network Architecture PyTorch Reproduction Series - Classification (1: LeNet5, VGG, AlexNet, ResNet) (2024)

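Introduction: This series reproduces classic deep learning network models for computer vision (classification, object detection, and semantic segmentation) in PyTorch, moving from simple to advanced so that beginners can use them.

All of the code has been run without errors!

We start with the classic classification networks. The backbones (items 10 and 11) are designed for object detection, but since their main job is feature extraction they are listed here as well: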

1.LeNet5(√)

2.VGG(√)

3.AlexNet(√)

4.ResNet(√)

5.ResNeXt

6.GoogLeNet

7.MobileNet

8.ShuffleNet

9.EfficientNet

10.VovNet

11.DarkNet

...

Notes:

a) The complete code is uploaded to my GitHub:

https://github.com/HanXiaoyiGitHub/Simple-CV-Pytorch-master

b) The environment I used is listed below (you do not have to use exactly this environment, as long as you are prepared to debug any issues that come up!)

python == 3.9.12
torch == 1.11.0+cu113
torchvision == 0.12.0+cu113
torchaudio == 0.11.0+cu113
pycocotools == 2.0.4
numpy
Cython
matplotlib
opencv-python
skimage
tensorboard
tqdm
thop

c) The classification datasets are ImageNet and CIFAR10, stored in the directories shown below (COCO and VOC are used for object detection; no semantic segmentation dataset is used yet).

dataset path: /data/
data
|
|----coco----|----coco2017
|
|----cifar
|
|----ImageNet----|----ILSVRC2012
|
|----VOCdevkit

coco2017 path: /data/coco/coco2017
coco2017
|
|----annotations
|----train2017
|----test2017
|----val2017

voc path: /data/VOCdevkit
|
|               |----Annotations
|               |----ImageSets
|----VOC2007----|----JPEGImages
|               |----SegmentationClass
|               |----SegmentationObject
|
|               |----Annotations
|               |----ImageSets
|----VOC2012----|----JPEGImages
|               |----SegmentationClass
|               |----SegmentationObject

ILSVRC2012 path: /data/ImageNet/ILSVRC2012
|
|----train
|
|----val

cifar path: /data/cifar
|
|----cifar-10-batches-py
|
|----cifar-10-python.tar.gz

d) I use AMP (automatic mixed precision) to speed up training on the GPU. If you are not familiar with it, see the following link:

How to speed up network model training with PyTorch (autocast and GradScaler): https://blog.csdn.net/XiaoyYidiaodiao/article/details/124854343?spm=1001.2014.3001.5502

For this reason, @autocast() must be added in front of the forward function of each network model. In addition, because I use a torch version later than 1.4, inplace has to be set to False in ReLU(inplace=False), Dropout(inplace=False), and similar layers.
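As a rough illustration of how autocast and GradScaler fit into one training step, here is a minimal sketch. The toy model, the random batch, and the `use_amp` flag are assumptions made for this example and are not taken from the repository; AMP is simply switched off when no CUDA device is present so the snippet also runs on a CPU.

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

# Hypothetical toy setup, used only to illustrate a single AMP training step.
use_amp = torch.cuda.is_available()          # AMP is disabled without CUDA
device = torch.device("cuda" if use_amp else "cpu")

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler(enabled=use_amp)

images = torch.randn(4, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (4,), device=device)

optimizer.zero_grad()
with autocast(enabled=use_amp):              # forward pass in mixed precision
    outputs = model(images)
    loss = criterion(outputs, targets)
scaler.scale(loss).backward()                # scale the loss before backward
scaler.step(optimizer)                       # unscale gradients, then step
scaler.update()                              # adapt the scale factor
```

Decorating forward with @autocast(), as the models in this post do, has the same effect as wrapping the forward call in the `with autocast():` block used here.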

e) LeNet5, VGG16, and AlexNet use fully connected layers, so they cannot accept a variable image size. For these architectures the image size must be fixed during preprocessing.
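To see why the input size has to be fixed, the following sketch computes how many features reach the first fully connected layer of LeNet5 for several image sizes. The helper names `conv_out` and `lenet5_flat_features` are illustrative and not part of the repository; the sketch assumes the usual output-size rule floor((n + 2p - k) / s) + 1 for convolution and pooling layers.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_flat_features(image_size):
    """Number of features that reach the first nn.Linear of LeNet5."""
    s = conv_out(image_size, kernel=5)       # 5x5 convolution, no padding
    s = conv_out(s, kernel=2, stride=2)      # 2x2 max pooling
    s = conv_out(s, kernel=5)                # 5x5 convolution, no padding
    s = conv_out(s, kernel=2, stride=2)      # 2x2 max pooling
    return 16 * s * s                        # 16 output channels


for size in (28, 32, 64):
    print(size, lenet5_flat_features(size))
```

Only a 32 × 32 input produces the 16 * 5 * 5 = 400 features that nn.Linear(16 * 5 * 5, 120) expects; any other size would cause a shape mismatch in the first fully connected layer.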

f) Project file structure

My OS is Ubuntu 20.04. The code also runs on Windows, which I have tested. Some folders are not needed yet; just leave them as they are, and I will explain them later.

project path: /data/PycharmProject/

Simple-CV-master path: /data/PycharmProject/Simple-CV-Pytorch-master
|
|----checkpoints ( resnet50-19c8e357.pth \COCO_ResNet50.pth[RetinaNet]\ VOC_ResNet50.pth[RetinaNet] )
|
|            |----cifar.py ( null, I just use torchvision.datasets.ImageFolder )
|            |----CIAR_labels.txt
|            |----coco.py
|            |----coco_eval.py
|            |----coco_labels.txt
|----data----|----__init__.py
|            |----config.py ( path )
|            |----imagenet.py ( null, I just use torchvision.datasets.ImageFolder )
|            |----ImageNet_labels.txt
|            |----voc0712.py
|            |----voc_eval.py
|            |----voc_labels.txt
|
|                                     |----crash_helmet.jpg
|----images----|----classification----|----sunflower.jpg
|              |                      |----photocopier.jpg
|              |                      |----automobile.jpg
|              |
|              |----detection----|----000001.jpg
|                                |----000001.xml
|                                |----000002.jpg
|                                |----000002.xml
|                                |----000003.jpg
|                                |----000003.xml
|
|----log(XXX[ detection or classification ]_XXX[ train or test or eval ].info.log)
|
|              |----__init__.py
|              |
|              |              |----__init.py
|              |----anchor----|----RetinaNetAnchors.py
|              |
|              |               |----lenet5.py
|              |               |----alexnet.py
|              |----basenet----|----vgg.py
|              |               |----resnet.py
|              |
|              |                 |----DarkNetBackbone.py
|              |----backbones----|----__init__.py ( not finished yet )
|              |                 |----ResNetBackbone.py
|              |                 |----VovNetBackbone.py
|              |
|----models----|----heads----|----__init.py
|              |             |----RetinaNetHeads.py
|              |
|              |              |----RetinaNetLoss.py
|              |----losses----|----__init.py
|              |
|              |             |----FPN.py
|              |----necks----|----__init__.py
|              |             |-----FPN.txt
|              |
|              |----RetinaNet.py
|
|----results ( eg: detection ( VOC or COCO AP ) )
|
|----tensorboard ( Loss visualization )
|
|                                    |----eval.py
|             |----classification----|----train.py
|             |                      |----test.py
|----tools----|
|             |                 |----eval_coco.py
|             |                 |----eval_voc.py
|             |----detection----|----test.py
|                               |----train.py
|
|             |----AverageMeter.py
|             |----BBoxTransform.py
|             |----ClipBoxes.py
|             |----Sampler.py
|             |----iou.py
|----utils----|----__init__.py
|             |----accuracy.py
|             |----augmentations.py
|             |----collate.py
|             |----get_logger.py
|             |----nms.py
|             |----path.py
|
|----FolderOrganization.txt
|
|----main.py
|
|----README.md
|
|----requirements.txt

1.LeNet5 (size: 32 × 32 × 3)[1]

Figure 1.

The code below reproduces the architecture in Figure 1.

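I added nn.BatchNorm2d() to improve accuracy. For a faithful reproduction of the original network, you can remove the nn.BatchNorm2d() layers from the code.

The output size of the last fully connected layer can be adjusted to the number of categories in the dataset.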

from torch import nn
from torch.cuda.amp import autocast


class lenet5(nn.Module):
    # cifar: 10, ImageNet: 1000
    def __init__(self, num_classes=1000, init_weights=False):
        super(lenet5, self).__init__()
        self.num_classes = num_classes
        self.layers = nn.Sequential(
            # input: 32 * 32 * 3 -> 28 * 28 * 6
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            # 28 * 28 * 6 -> 14 * 14 * 6
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # 14 * 14 * 6 -> 10 * 10 * 16
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # 10 * 10 * 16 -> 5 * 5 * 16
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.Linear(120, 84)
        )
        self.classifier = nn.Linear(84, self.num_classes)

        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.layers(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

2.AlexNet (size: 224 × 224 × 3)[2]

Figure 2.

If Figure 2 is not clear enough, see Figure 3 below.

Figure 3.

Figure 3 is then converted into Figure 4. The original AlexNet was split across two graphics cards because a single GPU did not have enough compute at that time. Hardware has since caught up, so the whole model can now run on one GPU.

Figure 4.
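The AlexNet code in this post writes channel counts such as (48*2) and (128*2), which reflects the merge of the two GPU branches into one. The sketch below illustrates that merge under stated assumptions: the per-branch widths are the ones from the original two-GPU design, and the hypothetical helper `conv_weights` counts the weights of a bias-free convolution so that a grouped second convolution (two independent branches) can be compared with the merged one.

```python
def conv_weights(in_channels, out_channels, kernel, groups=1):
    """Number of weights in a 2-D convolution without bias."""
    return (in_channels // groups) * out_channels * kernel * kernel


# Channel widths of one GPU branch in the original design.
per_gpu = [48, 128, 192, 192, 128]
# Widths after merging both branches onto a single GPU.
merged = [2 * c for c in per_gpu]
print(merged)

# Second convolution (5x5): two separate branches versus one merged layer.
grouped_conv2 = conv_weights(96, 256, 5, groups=2)
merged_conv2 = conv_weights(96, 256, 5, groups=1)
print(grouped_conv2, merged_conv2)
```

The merged layer holds twice as many weights as the grouped one because every output channel now sees all 96 input channels instead of only the 48 of its own branch.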

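The output size of the last fully connected layer can be adjusted to the number of categories in the dataset.

I added nn.BatchNorm2d() to improve accuracy. For a faithful reproduction of the original network, you can remove the nn.BatchNorm2d() layers from the code.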

import torch.nn as nn
from torch.cuda.amp import autocast


class alexnet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(alexnet, self).__init__()
        self.layers = nn.Sequential(
            # input: 224 * 224 * 3 -> 55 * 55 * (48*2)
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            # 55 * 55 * (48*2) -> 27 * 27 * (48*2)
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 27 * 27 * (48*2) -> 27 * 27 * (128*2)
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 27 * 27 * (128*2) -> 13 * 13 * (128*2)
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 13 * 13 * (128*2) -> 13 * 13 * (192*2)
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            # 13 * 13 * (192*2) -> 13 * 13 * (192*2)
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            # 13 * 13 * (192*2) -> 13 * 13 * (128*2)
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 13 * 13 * (128*2) -> 6 * 6 * (128*2)
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(6 * 6 * 128 * 2, 2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048, 2048),
            nn.ReLU()
        )
        self.classifier = nn.Linear(2048, num_classes)

        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.layers(x)
        x = self.fc(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

3.VGG (size: 224 × 224 × 3)[3]

Figure 5.

The code below reproduces the network architectures inside the green box in Figure 5.

The accuracy was unacceptably poor, so I added nn.BatchNorm2d(i) and used transfer learning.

The output size of the last fully connected layer can be adjusted to the number of categories in the dataset.

import torch
from torch import nn
from utils.path import CheckPoints
from torch.cuda.amp import autocast

__all__ = [
    'vgg11',
    'vgg13',
    'vgg16',
    'vgg19',
]

# if your network is limited, you can download them, and put them into CheckPoints(my Project:Simple-CV-Pytorch-master/checkpoints/).
model_urls = {
    # 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg11': '{}/vgg11-bbd30ac9.pth'.format(CheckPoints),
    # 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg13': '{}/vgg13-c768596a.pth'.format(CheckPoints),
    # 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg16': '{}/vgg16-397923af.pth'.format(CheckPoints),
    # 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg19': '{}/vgg19-dcbb9e9d.pth'.format(CheckPoints)
}


def vgg_(arch, num_classes, pretrained, init_weights=False, **kwargs):
    cfg = cfgs["vgg" + arch]
    features = make_features(cfg)
    model = vgg(num_classes=num_classes, features=features, init_weights=init_weights, **kwargs)
    # if you're training for the first time, no pretrained is required!
    if pretrained:
        pretrained_models = torch.load(model_urls["vgg" + arch])
        # transfer learning
        # if you want to train your own dataset
        if arch == '11':
            del pretrained_models['features.8.weight']
            del pretrained_models['features.11.weight']
            del pretrained_models['features.16.weight']
        elif arch == '13':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.15.weight']
            del pretrained_models['features.17.weight']
            del pretrained_models['features.22.weight']
        elif arch == '16':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.14.weight']
            del pretrained_models['features.17.weight']
            del pretrained_models['features.21.weight']
            del pretrained_models['features.24.weight']
            del pretrained_models['features.28.weight']
        elif arch == '19':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.14.weight']
            del pretrained_models['features.21.weight']
            del pretrained_models['features.23.weight']
            del pretrained_models['features.28.weight']
            del pretrained_models['features.34.weight']
        else:
            raise ValueError("Pretrained: unsupported VGG depth")
        model.load_state_dict(pretrained_models, strict=False)
    return model


def vgg11(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('11', num_classes, pretrained, init_weights, **kwargs)


def vgg13(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('13', num_classes, pretrained, init_weights, **kwargs)


def vgg16(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('16', num_classes, pretrained, init_weights, **kwargs)


def vgg19(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('19', num_classes, pretrained, init_weights, **kwargs)


class vgg(nn.Module):
    # cifar: 10, ImageNet: 1000
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(vgg, self).__init__()
        self.num_classes = num_classes
        self.features = features
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(7 * 7 * 512, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.classifier = nn.Linear(4096, self.num_classes)

        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm weights are 1-D, so xavier_uniform_ cannot be applied to them
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_features(cfgs: list):
    layers = []
    in_channels = 3
    for i in cfgs:
        if i == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, i, kernel_size=3, stride=1, padding=1, bias=False)
            layers += [conv2d, nn.BatchNorm2d(i), nn.ReLU()]
            in_channels = i
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
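The cfgs lists encode both the depth in each model name and the 7 * 7 * 512 input size of the first fully connected layer. The following sketch checks both; the helper names `weight_layers` and `final_feature_map` are illustrative and not part of the repository, and a 224 × 224 input is assumed.

```python
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def weight_layers(cfg):
    """Convolution layers in the cfg plus the three fully connected layers."""
    return sum(1 for v in cfg if v != 'M') + 3


def final_feature_map(cfg, image_size=224):
    """Side length and channel count of the last feature map."""
    size = image_size
    for v in cfg:
        if v == 'M':
            size //= 2                       # each 2x2 max pooling halves the side
    channels = [v for v in cfg if v != 'M'][-1]
    return size, channels


for name, cfg in cfgs.items():
    print(name, weight_layers(cfg), final_feature_map(cfg))
```

The 3 × 3 convolutions use padding=1 and therefore keep the spatial size, so only the five pooling layers shrink the 224-pixel side down to 7.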

4.ResNet[4]

Figure 6.

The code below reproduces the architectures in Figure 6 (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152).

First, let's look at how each block is reproduced. The block used by the 18- and 34-layer networks is shown in the green box in Figure 7 below, and the block used by the 50-, 101-, and 152-layer networks is shown in the blue box.

Figure 7.

Block: 18-layer, 34-layer

import torch.nn as nn


# 18-layer, 34-layer
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out

Block: 50-layer, 101-layer, 152-layer

# 50-layer, 101-layer, 152-layer
class Bottleneck(nn.Module):
    """
    self.conv1(kernel_size=1,stride=2)
    self.conv2(kernel_size=3,stride=1)
    to
    self.conv1(kernel_size=1,stride=1)
    self.conv2(kernel_size=3,stride=2)
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # padding=1 keeps the spatial size consistent with the shortcut branch
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out

To reproduce the whole ResNet model, first reproduce the initial convolution layer and the max pooling layer.

class resnet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(resnet, self).__init__()
        self.include_top = include_top
        self.in_channels = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

The subsequent layers are shown in Figure 8. The modules in the figure correspond to the code as follows:

conv2_x -> self.layer1, conv3_x -> self.layer2, conv4_x -> self.layer3, conv5_x -> self.layer4

        ...
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        ...

The code below reproduces the dashed shortcut connection for the 50-, 101-, and 152-layer networks. The 18- and 34-layer networks work the same way, so they are not shown separately.

Figure 8.

    def _make_layer(self, block, channels, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channels, out_channels=channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion)
            )
        ...
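The condition in _make_layer decides which stages get a projection shortcut on their first block. The sketch below replays that condition for the four stages without building any layers; `needs_downsample` and `trace` are hypothetical helper names introduced only for this illustration.

```python
def needs_downsample(in_channels, channels, expansion, stride):
    """Same condition as in _make_layer: shape changes need a 1x1 projection."""
    return stride != 1 or in_channels != channels * expansion


def trace(expansion):
    """Replay the four stages (layer1..layer4) for a given block expansion."""
    in_channels = 64                          # channels after the stem
    stages = [(64, 1), (128, 2), (256, 2), (512, 2)]
    result = []
    for channels, stride in stages:
        result.append(needs_downsample(in_channels, channels, expansion, stride))
        in_channels = channels * expansion    # output width of the stage
    return result


print("BasicBlock:", trace(expansion=1))
print("Bottleneck:", trace(expansion=4))
```

With BasicBlock (expansion 1) the first stage keeps the identity shortcut because neither the stride nor the channel count changes, while Bottleneck (expansion 4) already needs a projection in the first stage to go from 64 to 256 channels.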

Next, call the ResNet model and choose the desired depth (18, 34, 50, 101, 152).

def resnet18(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('18', BasicBlock, [2, 2, 2, 2], num_classes, pretrained, include_top)


def resnet34(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('34', BasicBlock, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet50(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('50', Bottleneck, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet101(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('101', Bottleneck, [3, 4, 23, 3], num_classes, pretrained, include_top)


def resnet152(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('152', Bottleneck, [3, 8, 36, 3], num_classes, pretrained, include_top)
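The block counts passed to each constructor determine the depth in the model name. As a quick check, the sketch below counts weighted layers as convolutions per block times the number of blocks, plus the first 7 × 7 convolution and the final fully connected layer. The `configs` table and the `depth` helper are written for this example and are not part of the repository.

```python
# (convolutions per block, blocks per stage) for each ResNet variant
configs = {
    18: (2, [2, 2, 2, 2]),      # BasicBlock: two 3x3 convolutions
    34: (2, [3, 4, 6, 3]),
    50: (3, [3, 4, 6, 3]),      # Bottleneck: 1x1, 3x3, 1x1 convolutions
    101: (3, [3, 4, 23, 3]),
    152: (3, [3, 8, 36, 3]),
}


def depth(convs_per_block, blocks_num):
    """Weighted layers: block convolutions + stem convolution + fc layer."""
    return convs_per_block * sum(blocks_num) + 2


for name, (convs, blocks) in configs.items():
    print(name, depth(convs, blocks))
```

The 1 × 1 convolutions in the projection shortcuts are not counted, which matches how the depth is conventionally stated.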

Complete code

import torch
import torch.nn as nn
from utils.path import CheckPoints
from torch.cuda.amp import autocast

__all__ = [
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152'
]

# if your network is limited, you can download them, and put them into CheckPoints(my Project:Simple-CV-Pytorch-master/checkpoints/).
model_urls = {
    # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet18': '{}/resnet18-5c106cde.pth'.format(CheckPoints),
    # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet34': '{}/resnet34-333f7ec4.pth'.format(CheckPoints),
    # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet50': '{}/resnet50-19c8e357.pth'.format(CheckPoints),
    # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet101': '{}/resnet101-5d3b4d8f.pth'.format(CheckPoints),
    # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnet152': '{}/resnet152-b121ed2d.pth'.format(CheckPoints)
}


def resnet_(arch, block, block_num, num_classes, pretrained, include_top, **kwargs):
    model = resnet(block=block, blocks_num=block_num, num_classes=num_classes, include_top=include_top, **kwargs)
    # if you're training for the first time, no pretrained is required!
    if pretrained:
        # if you want to use cpu, you should modify map_location=torch.device("cpu")
        pretrained_models = torch.load(model_urls["resnet" + arch], map_location=torch.device("cuda:0"))
        # transfer learning
        # if you want to train your own dataset
        # del pretrained_models['module.classifier.bias']
        model.load_state_dict(pretrained_models, strict=False)
    return model


# 18-layer, 34-layer
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    @autocast()
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out


# 50-layer, 101-layer, 152-layer
class Bottleneck(nn.Module):
    """
    self.conv1(kernel_size=1,stride=2)
    self.conv2(kernel_size=3,stride=1)
    to
    self.conv1(kernel_size=1,stride=1)
    self.conv2(kernel_size=3,stride=2)
    acc: up 0.5%
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # padding=1 keeps the spatial size consistent with the shortcut branch
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    @autocast()
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out


class resnet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(resnet, self).__init__()
        self.include_top = include_top
        self.in_channels = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.flatten = nn.Flatten()
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channels, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channels, out_channels=channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion)
            )

        layers = []
        layers.append(block(in_channels=self.in_channels, out_channels=channels, downsample=downsample, stride=stride))
        self.in_channels = channels * block.expansion

        for _ in range(1, block_num):
            layers.append(block(in_channels=self.in_channels, out_channels=channels))
        return nn.Sequential(*layers)

    @autocast()
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = self.flatten(x)
            x = self.fc(x)
        return x


def resnet18(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('18', BasicBlock, [2, 2, 2, 2], num_classes, pretrained, include_top)


def resnet34(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('34', BasicBlock, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet50(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('50', Bottleneck, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet101(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('101', Bottleneck, [3, 4, 23, 3], num_classes, pretrained, include_top)


def resnet152(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('152', Bottleneck, [3, 8, 36, 3], num_classes, pretrained, include_top)

Some configuration files

utils/path.py

import os.path
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(BASE_DIR)

# Gets home dir cross platform
# "/data/"
MyName = "PycharmProject"
Folder = "Simple-CV-Pytorch-master"

# Path to store checkpoint model
CheckPoints = 'checkpoints'
CheckPoints = os.path.join(BASE_DIR, MyName, Folder, CheckPoints)

# Path to store tensorboard load
tensorboard_log = 'tensorboard'
tensorboard_log = os.path.join(BASE_DIR, MyName, Folder, tensorboard_log)

# Path to save log
log = 'log'
log = os.path.join(BASE_DIR, MyName, Folder, log)

# Path to save classification train log
classification_train_log = 'classification_train'

# Path to save classification test log
classification_test_log = 'classification_test'

# Path to save classification eval log
classification_eval_log = 'classification_eval'

# Classification evaluate model path
classification_evaluate = None

# Images classification path
image_cls = 'automobile.jpg'
images_cls_path = 'images/classification'
images_cls_path = os.path.join(BASE_DIR, MyName, Folder, images_cls_path, image_cls)

# Data
DATAPATH = BASE_DIR

# ImageNet/ILSVRC2012
ImageNet = "ImageNet/ILSVRC2012"
ImageNet_Train_path = os.path.join(DATAPATH, ImageNet, 'train')
ImageNet_Eval_path = os.path.join(DATAPATH, ImageNet, 'val')

# CIFAR10
CIFAR = 'cifar'
CIFAR_path = os.path.join(DATAPATH, CIFAR)
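BASE_DIR in path.py is obtained by walking four directory levels up from the file itself. The sketch below replays that walk on the project layout described earlier in this post, assuming path.py lives at /data/PycharmProject/Simple-CV-Pytorch-master/utils/path.py. The `ancestor` helper is hypothetical, and posixpath is used so that the result is the same on every OS.

```python
import posixpath


def ancestor(path, levels):
    """Apply dirname() the given number of times."""
    for _ in range(levels):
        path = posixpath.dirname(path)
    return path


# Assumed location of path.py, taken from the project layout in this post.
path_py = "/data/PycharmProject/Simple-CV-Pytorch-master/utils/path.py"

base_dir = ancestor(path_py, 4)
checkpoints = posixpath.join(base_dir, "PycharmProject", "Simple-CV-Pytorch-master", "checkpoints")

print(base_dir)
print(checkpoints)
```

If the project is placed somewhere else, MyName and Folder in path.py have to be changed so that the joined paths still point into the project directory.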

data/config.py

from utils import path

# Path to save log
log = path.log

# Path to save classification train log
classification_train_log = path.classification_train_log

# Path to save classification test log
classification_test_log = path.classification_test_log

# Path to save classification eval log
classification_eval_log = path.classification_eval_log

# Path to store checkpoint model
checkpoint_path = path.CheckPoints

# Classification evaluate model path
classification_evaluate = path.classification_evaluate

# Classification test images
images_cls_root = path.images_cls_path

# Path to save tensorboard
tensorboard_log = path.tensorboard_log

Training code

tools/classification/train.py

import os
import logging
import argparse
import warnings

warnings.filterwarnings('ignore')

import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

import time
import torch
from data import *
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
from torchvision import transforms
from utils.accuracy import accuracy
from torch.utils.data import DataLoader
from utils.get_logger import get_logger
from models.basenets.lenet5 import lenet5
from models.basenets.alexnet import alexnet
from utils.AverageMeter import AverageMeter
from torch.cuda.amp import autocast, GradScaler
from models.basenets.vgg import vgg11, vgg13, vgg16, vgg19
from models.basenets.resnet import resnet18, resnet34, resnet50, resnet101, resnet152


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
    parser.add_mutually_exclusive_group()
    parser.add_argument('--dataset', type=str, default='CIFAR', choices=['ImageNet', 'CIFAR'], help='ImageNet, CIFAR')
    parser.add_argument('--dataset_root', type=str, default=CIFAR_ROOT, choices=[ImageNet_Train_ROOT, CIFAR_ROOT], help='Dataset root directory path')
    parser.add_argument('--basenet', type=str, default='lenet', choices=['resnet', 'vgg', 'lenet', 'alexnet'], help='Pretrained base model')
    parser.add_argument('--depth', type=int, default=5, help='BaseNet depth, including: LeNet of 5, AlexNet of 0, VGG of 11, 13, 16, 19, ResNet of 18, 34, 50, 101, 152')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--resume', type=str, default=None, help='Checkpoint state_dict file to resume training from')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers used in dataloading')
    parser.add_argument('--cuda', type=str, default=True, help='Use CUDA to train model')
    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum value for optim')
    parser.add_argument('--gamma', type=float, default=0.1, help='Gamma update for SGD')
    parser.add_argument('--accumulation_steps', type=int, default=1, help='Gradient accumulation steps')
    parser.add_argument('--save_folder', type=str, default=config.checkpoint_path, help='Directory for saving checkpoint models')
    parser.add_argument('--tensorboard', type=str, default=False, help='Use tensorboard for loss visualization')
    parser.add_argument('--log_folder', type=str, default=config.log, help='Log Folder')
    parser.add_argument('--log_name', type=str, default=config.classification_train_log, help='Log Name')
    parser.add_argument('--tensorboard_log', type=str, default=config.tensorboard_log, help='Use tensorboard for loss visualization')
    parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
    parser.add_argument('--epochs', type=int, default=30, help='Number of epochs')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--milestones', type=list, default=[15, 20, 30], help='Milestones')
    parser.add_argument('--num_classes', type=int, default=10, help='the number classes, like ImageNet:1000, cifar:10')
    parser.add_argument('--image_size', type=int, default=32, help='image size, like ImageNet:224, cifar:32')
    parser.add_argument('--pretrained', type=str, default=True, help='Models was pretrained')
    parser.add_argument('--init_weights', type=str, default=False, help='Init Weights')
    return parser.parse_args()


args = parse_args()

# 1. Log
get_logger(args.log_folder, args.log_name)
logger = logging.getLogger(args.log_name)

# 2. Torch choose cuda or cpu
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but you aren't using it" +
              "\n You can set the parameter of cuda to True.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)


def train():
    # 3. Create SummaryWriter
    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        # tensorboard loss
        writer = SummaryWriter(args.tensorboard_log)

    # vgg16, alexnet and lenet5 need to resize image_size, because of fc.
    if args.basenet == 'vgg' or args.basenet == 'alexnet':
        args.image_size = 224
    elif args.basenet == 'lenet':
        args.image_size = 32

    # 4. Ready dataset
    if args.dataset == 'ImageNet':
        if args.dataset_root == CIFAR_ROOT:
            raise ValueError('Must specify dataset_root if specifying dataset ImageNet2012.')
        elif not os.path.exists(ImageNet_Train_ROOT):
            raise ValueError("WARNING: Using default ImageNet2012 dataset_root because " +
                             "--dataset_root was not specified.")
        dataset = torchvision.datasets.ImageFolder(
            root=args.dataset_root,
            transform=torchvision.transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]))
    elif args.dataset == 'CIFAR':
        if args.dataset_root == ImageNet_Train_ROOT:
            raise ValueError('Must specify dataset_root if specifying dataset CIFAR10.')
        elif args.dataset_root is None:
            raise ValueError("Must provide --dataset_root when training on CIFAR10.")
        dataset = torchvision.datasets.CIFAR10(
            root=args.dataset_root,
            train=True,
            transform=torchvision.transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                torchvision.transforms.ToTensor()]))
    else:
        raise ValueError('Dataset type not understood (must be ImageNet or CIFAR), exiting.')

    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True,
                                             num_workers=args.num_workers, pin_memory=False,
                                             generator=torch.Generator(device='cuda'))

    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    # 5. Define train model
    # Unfortunately, Lenet5 and Alexnet don't provide pretrained Model.
    if args.basenet == 'lenet':
        if args.depth == 5:
            model = lenet5(num_classes=args.num_classes, init_weights=args.init_weights)
        else:
            raise ValueError('Unsupported LeNet depth!')
    elif args.basenet == 'alexnet':
        model = alexnet(num_classes=args.num_classes, init_weights=args.init_weights)
    elif args.basenet == 'vgg':
        if args.depth == 11:
            model = vgg11(pretrained=args.pretrained, num_classes=args.num_classes, init_weights=args.init_weights)
        elif args.depth == 13:
            model = vgg13(pretrained=args.pretrained, num_classes=args.num_classes, init_weights=args.init_weights)
        elif args.depth == 16:
            model = vgg16(pretrained=args.pretrained, num_classes=args.num_classes, init_weights=args.init_weights)
        elif args.depth == 19:
            model = vgg19(pretrained=args.pretrained, num_classes=args.num_classes, init_weights=args.init_weights)
        else:
            raise ValueError('Unsupported VGG depth!')
    # Unfortunately for my resnet, there is no set init_weight, because I'm going to set object detection algorithm
    elif args.basenet == 'resnet':
        if args.depth == 18:
            model = resnet18(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=args.pretrained, num_classes=args.num_classes)  # False means the models was not trained
        elif args.depth == 101:
            model = resnet101(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=args.pretrained, num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported ResNet depth!')
    else:
        raise ValueError('Unsupported model type!')

    if args.cuda:
        if torch.cuda.is_available():
            model = model.cuda()
            model = torch.nn.DataParallel(model).cuda()
    else:
        model = torch.nn.DataParallel(model)

    # 6. Loading weights
    if args.resume:
        other, ext = os.path.splitext(args.resume)
        # check the extension against both supported suffixes
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            model_load = os.path.join(args.save_folder, args.resume)
            model.load_state_dict(torch.load(model_load))
        else:
            print('Sorry only .pth and .pkl files supported.')

    if args.init_weights:
        # initialize newly added models' weights with xavier method
        if args.basenet == 'resnet':
            print("There is no set init_weight, because I'm going to set object detection algorithm.")
        else:
            print("Initializing weights...")
    else:
        print("Not Initializing weights...")

    if args.pretrained:
        if args.basenet == 'lenet' or args.basenet == 'alexnet':
            print("There is no available pretrained model on the website. ")
        else:
            print("Models was pretrained...")
    else:
        print("Pretrained models is False...")

    model.train()
    iteration = 0

    # 7. Optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    scaler = GradScaler()

    # 8. Length
    iter_size = len(dataset) // args.batch_size
    print("len(dataset): {}, iter_size: {}".format(len(dataset), iter_size))
    logger.info(f"args - {args}")
    t0 = time.time()

    # 9. Create batch iterator
    for epoch in range(args.epochs):
        t1 = time.time()
        torch.cuda.empty_cache()

        # 10. Load train data
        for data in dataloader:
            iteration += 1
            images, targets = data

            # 11. Backward
            optimizer.zero_grad()
            if args.cuda:
                images, targets = images.cuda(), targets.cuda()
                criterion = criterion.cuda()

            # 12. 
Forward with autocast(): outputs = model(images) loss = criterion(outputs, targets) loss = loss / args.accumulation_steps if args.tensorboard: writer.add_scalar("train_classification_loss", loss.item(), iteration) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() # 13. Measure accuracy and record loss acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) top1.update(acc1.item(), images.size(0)) top5.update(acc5.item(), images.size(0)) losses.update(loss.item(), images.size(0)) if iteration % 100 == 0: logger.info( f"- epoch: {epoch}, iteration: {iteration}, lr: {optimizer.param_groups[0]['lr']}, " f"top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, " f"loss: {loss.item():.3f}, (losses.avg): {losses.avg:3f} " ) scheduler.step(losses.avg) t2 = time.time() h_time = (t2 - t1) // 3600 m_time = ((t2 - t1) % 3600) // 60 s_time = ((t2 - t1) % 3600) % 60 print("epoch {} is finished, and the time is {}h{}min{}s".format(epoch, int(h_time), int(m_time), int(s_time))) # 14. Save train model if epoch != 0 and epoch % 10 == 0: print('Saving state, iter:', epoch) torch.save(model.state_dict(), args.save_folder + '/' + args.dataset + '_' + args.basenet + str(args.depth) + '_' + repr(epoch) + '.pth') torch.save(model.state_dict(), args.save_folder + '/' + args.dataset + "_" + args.basenet + str(args.depth) + '.pth') if args.tensorboard: writer.close() t3 = time.time() h = (t3 - t0) // 3600 m = ((t3 - t0) % 3600) // 60 s = ((t3 - t0) % 3600) % 60 print("The Finished Time is {}h{}m{}s".format(int(h), int(m), int(s))) return top1.avg, top5.avg, losses.avgif __name__ == '__main__': torch.multiprocessing.set_start_method('spawn') logger.info("Program started") top1, top5, loss = train() print("top1 acc: {}, top5 acc: {}, loss:{}".format(top1, top5, loss)) logger.info("Done!")
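
A note on the command-line flags: when a flag is declared with type=str and a boolean default, any value typed on the command line arrives as a non-empty string, and a non-empty string such as 'False' is still truthy in Python. The sketch below shows one possible way to parse such flags into real booleans and to read a list of milestones. The helper names str2bool and build_parser are hypothetical and are not part of the repository; only the flag names are taken from the script.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false spellings to a real bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def build_parser():
    """Hypothetical parser with only the flags needed for this illustration."""
    parser = argparse.ArgumentParser(description='flag parsing sketch')
    parser.add_argument('--tensorboard', type=str2bool, default=False)
    parser.add_argument('--pretrained', type=str2bool, default=True)
    # nargs='+' collects one or more ints into a list
    parser.add_argument('--milestones', type=int, nargs='+', default=[15, 20, 30])
    return parser


# Simulate: python train.py --pretrained False --milestones 10 20
args = build_parser().parse_args(['--pretrained', 'False', '--milestones', '10', '20'])
print(args.pretrained, args.tensorboard, args.milestones)
```

With this kind of parsing, `--pretrained False` really disables the pretrained weights instead of being read as a truthy string.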

Test code

tools/classification/test.py

import logging
import os
import argparse
import warnings

warnings.filterwarnings('ignore')
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
import time
from data import *
from PIL import Image
import torch
import torch.nn.parallel
from torchvision import transforms
from utils.get_logger import get_logger
from models.basenets.lenet5 import lenet5
from models.basenets.alexnet import alexnet
from models.basenets.vgg import vgg11, vgg13, vgg16, vgg19
from models.basenets.resnet import resnet18, resnet34, resnet50, resnet101, resnet152


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Classification Testing')
    parser.add_argument('--dataset', type=str, default='CIFAR', choices=['ImageNet', 'CIFAR'],
                        help='ImageNet, CIFAR')
    parser.add_argument('--images_root', type=str, default=config.images_cls_root,
                        help='Path of the image to classify')
    parser.add_argument('--basenet', type=str, default='alexnet',
                        choices=['resnet', 'vgg', 'lenet', 'alexnet'],
                        help='Pretrained base model')
    parser.add_argument('--depth', type=int, default=0,
                        help='BaseNet depth, including: LeNet of 5, AlexNet of 0, '
                             'VGG of 11, 13, 16, 19, ResNet of 18, 34, 50, 101, 152')
    parser.add_argument('--evaluate', type=str, default=config.classification_evaluate,
                        help='Checkpoint state_dict file to evaluate')
    parser.add_argument('--save_folder', type=str, default=config.checkpoint_path,
                        help='Directory for saving checkpoint models')
    parser.add_argument('--log_folder', type=str, default=config.log,
                        help='Log Folder')
    parser.add_argument('--log_name', type=str, default=config.classification_test_log,
                        help='Log Name')
    parser.add_argument('--cuda', type=str, default=True,
                        help='Use CUDA to test model')
    parser.add_argument('--num_classes', type=int, default=10,
                        help='the number classes, like ImageNet:1000, cifar:10')
    parser.add_argument('--image_size', type=int, default=32,
                        help='image size, like ImageNet:224, cifar:32')
    parser.add_argument('--pretrained', type=str, default=False,
                        help='Models was pretrained')
    return parser.parse_args()


args = parse_args()

# 1. Torch choose cuda or cpu
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print("WARNING: It looks like you have a CUDA device, but you aren't using it" +
              "\n You can set the parameter of cuda to True.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)

# 2. Log
get_logger(args.log_folder, args.log_name)
logger = logging.getLogger(args.log_name)


def get_label_file(filename):
    # The original called os.mkdir here, which creates a directory at the
    # file path and makes the later open() fail; report the problem instead.
    if not os.path.exists(filename):
        raise FileNotFoundError("The dataset label file is missing: " + filename)
    return filename


def dataset_labels_results(filename, output):
    filename = os.path.join(BASE_DIR, 'data', filename + '_labels.txt')
    get_label_file(filename=filename)
    # the with-block closes the file, so no explicit f.close() is needed
    with open(file=filename, mode='r') as f:
        labels = f.readlines()
    index = int(output.cpu().numpy()[0])
    return labels[index].strip()


def test():
    # Use the GPU only when it is both requested and available
    use_cuda = bool(args.cuda) and torch.cuda.is_available()

    # vgg16, alexnet and lenet5 need to resize image_size, because of fc.
    if args.basenet == 'vgg' or args.basenet == 'alexnet':
        args.image_size = 224
    elif args.basenet == 'lenet':
        args.image_size = 32

    # 3. Ready image
    if args.images_root is None:
        raise ValueError("The images is None, you should load image!")
    # convert to RGB so grayscale or RGBA files also yield 3 channels
    image = Image.open(args.images_root).convert('RGB')
    transform_list = [transforms.Resize((args.image_size, args.image_size)),
                      transforms.ToTensor()]
    # match the normalization applied to ImageNet images in the training script
    if args.dataset == 'ImageNet':
        transform_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                   std=[0.229, 0.224, 0.225]))
    transform = transforms.Compose(transform_list)
    image = transform(image)
    image = image.reshape(1, 3, args.image_size, args.image_size)

    # 4. Define model
    if args.basenet == 'lenet':
        if args.depth == 5:
            model = lenet5(num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported LeNet depth!')
    elif args.basenet == 'alexnet':
        model = alexnet(num_classes=args.num_classes)
    elif args.basenet == 'vgg':
        if args.depth == 11:
            model = vgg11(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 13:
            model = vgg13(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 16:
            model = vgg16(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 19:
            model = vgg19(pretrained=args.pretrained, num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported VGG depth!')
    elif args.basenet == 'resnet':
        if args.depth == 18:
            model = resnet18(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 50:
            # pretrained=False means no pretrained weights are downloaded
            model = resnet50(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 101:
            model = resnet101(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=args.pretrained, num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported ResNet depth!')
    else:
        raise ValueError('Unsupported model type!')

    if use_cuda:
        model = torch.nn.DataParallel(model.cuda()).cuda()
    else:
        model = torch.nn.DataParallel(model)

    # 5. Loading model
    if args.evaluate:
        other, ext = os.path.splitext(args.evaluate)
        # "ext == '.pkl' or '.pth'" was always true; test membership instead
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            model_evaluate_load = os.path.join(args.save_folder, args.evaluate)
            model.load_state_dict(torch.load(model_evaluate_load))
        else:
            print('Sorry only .pth and .pkl files supported.')
    else:
        print("Sorry, you should load weights! ")

    model.eval()

    # 6. print
    logger.info(f"args - {args}")

    # 7. Test
    with torch.no_grad():
        t0 = time.time()

        # 8. Forward
        if use_cuda:
            image = image.cuda()
        output = model(image)
        output = output.argmax(1)

        t1 = time.time()
        m = (t1 - t0) // 60
        s = (t1 - t0) % 60

        folder_name = args.dataset
        output = dataset_labels_results(filename=folder_name, output=output)
        logger.info(f"output: {output}")
        print("It took a total of {}m{}s to complete the testing.".format(int(m), int(s)))
    return output


if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    logger.info("Program started")
    output = test()
    logger.info("Done!")
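
Because the training script saves the state_dict of a model wrapped in torch.nn.DataParallel, every key in the checkpoint carries a `module.` prefix. The test script wraps the model the same way, so the keys match. If you want to load such a checkpoint into a plain, unwrapped model, the prefix has to be removed first. The sketch below illustrates the idea with ordinary dictionaries standing in for a real state_dict; strip_module_prefix is a hypothetical helper, and the key names and values are made up for the example.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix='module.'):
    """Hypothetical helper: return a copy whose keys have the leading prefix removed."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # only strip the prefix when it is at the very start of the key
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in for torch.load(...): made-up keys and placeholder values
checkpoint = OrderedDict([('module.conv1.weight', 'w0'), ('module.fc.bias', 'b0')])
cleaned = strip_module_prefix(checkpoint)
print(list(cleaned.keys()))
```

The cleaned dictionary could then be passed to load_state_dict of a model that is not wrapped in DataParallel.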

Labels

CIFAR_label.txt

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
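
If the label file stores the mapping as a single Python dict literal, as printed above, reading it line by line does not give one entry per class id. A minimal sketch of an alternative, assuming that file format: parse the text with ast.literal_eval and look the class id up in the resulting dictionary. load_label_map is a hypothetical helper, and the text is passed in as a string here so the example runs without any file on disk.

```python
import ast


def load_label_map(text):
    """Hypothetical helper: parse a dict literal like "{0: 'airplane', ...}" into {int: str}."""
    labels = ast.literal_eval(text.strip())
    if not isinstance(labels, dict):
        raise ValueError('label text does not contain a dict literal')
    return {int(key): str(value) for key, value in labels.items()}


cifar_text = ("{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', "
              "5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}")
labels = load_label_map(cifar_text)
print(labels[3])  # cat
```

ast.literal_eval only evaluates literals, so it is a safer choice than eval for reading such a file.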

ImageNet_label.txt

{0: 'tench, Tinca tinca', 1: 'goldfish, Carassius auratus', 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 3: 'tiger shark, Galeocerdo cuvieri', 4: 'hammerhead, hammerhead shark', 5: 'electric ray, crampfish, numbfish, torpedo', 6: 'stingray', 7: 'cock', 8: 'hen', 9: 'ostrich, Struthio camelus', 10: 'brambling, Fringilla montifringilla', 11: 'goldfinch, Carduelis carduelis', 12: 'house finch, linnet, Carpodacus mexicanus', 13: 'junco, snowbird', 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', 15: 'robin, American robin, Turdus migratorius', 16: 'bulbul', 17: 'jay', 18: 'magpie', 19: 'chickadee', 20: 'water ouzel, dipper', 21: 'kite', 22: 'bald eagle, American eagle, Haliaeetus leucocephalus', 23: 'vulture', 24: 'great grey owl, great gray owl, Strix nebulosa', 25: 'European fire salamander, Salamandra salamandra', 26: 'common newt, Triturus vulgaris', 27: 'eft', 28: 'spotted salamander, Ambystoma maculatum', 29: 'axolotl, mud puppy, Ambystoma mexicanum', 30: 'bullfrog, Rana catesbeiana', 31: 'tree frog, tree-frog', 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', 33: 'loggerhead, loggerhead turtle, Caretta caretta', 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', 35: 'mud turtle', 36: 'terrapin', 37: 'box turtle, box tortoise', 38: 'banded gecko', 39: 'common iguana, iguana, Iguana iguana', 40: 'American chameleon, anole, Anolis carolinensis', 41: 'whiptail, whiptail lizard', 42: 'agama', 43: 'frilled lizard, Chlamydosaurus kingi', 44: 'alligator lizard', 45: 'Gila monster, Heloderma suspectum', 46: 'green lizard, Lacerta viridis', 47: 'African chameleon, Chamaeleo chamaeleon', 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', 49: 'African crocodile, Nile crocodile, Crocodylus niloticus', 50: 'American alligator, Alligator mississipiensis', 51: 'triceratops', 52: 'thunder snake, worm snake, Carphophis amoenus', 53: 
'ringneck snake, ring-necked snake, ring snake', 54: 'hognose snake, puff adder, sand viper', 55: 'green snake, grass snake', 56: 'king snake, kingsnake', 57: 'garter snake, grass snake', 58: 'water snake', 59: 'vine snake', 60: 'night snake, Hypsiglena torquata', 61: 'boa constrictor, Constrictor constrictor', 62: 'rock python, rock snake, Python sebae', 63: 'Indian cobra, Naja naja', 64: 'green mamba', 65: 'sea snake', 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus', 68: 'sidewinder, horned rattlesnake, Crotalus cerastes', 69: 'trilobite', 70: 'harvestman, daddy longlegs, Phalangium opilio', 71: 'scorpion', 72: 'black and gold garden spider, Argiope aurantia', 73: 'barn spider, Araneus cavaticus', 74: 'garden spider, Aranea diademata', 75: 'black widow, Latrodectus mactans', 76: 'tarantula', 77: 'wolf spider, hunting spider', 78: 'tick', 79: 'centipede', 80: 'black grouse', 81: 'ptarmigan', 82: 'ruffed grouse, partridge, Bonasa umbellus', 83: 'prairie chicken, prairie grouse, prairie fowl', 84: 'peacock', 85: 'quail', 86: 'partridge', 87: 'African grey, African gray, Psittacus erithacus', 88: 'macaw', 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', 90: 'lorikeet', 91: 'coucal', 92: 'bee eater', 93: 'hornbill', 94: 'hummingbird', 95: 'jacamar', 96: 'toucan', 97: 'drake', 98: 'red-breasted merganser, Mergus serrator', 99: 'goose', 100: 'black swan, Cygnus atratus', 101: 'tusker', 102: 'echidna, spiny anteater, anteater', 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', 104: 'wallaby, brush kangaroo', 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', 106: 'wombat', 107: 'jellyfish', 108: 'sea anemone, anemone', 109: 'brain coral', 110: 'flatworm, platyhelminth', 111: 'nematode, nematode worm, roundworm', 112: 'conch', 113: 'snail', 114: 'slug', 115: 'sea slug, nudibranch', 116: 
'chiton, coat-of-mail shell, sea cradle, polyplacophore', 117: 'chambered nautilus, pearly nautilus, nautilus', 118: 'Dungeness crab, Cancer magister', 119: 'rock crab, Cancer irroratus', 120: 'fiddler crab', 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus', 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', 124: 'crayfish, crawfish, crawdad, crawdaddy', 125: 'hermit crab', 126: 'isopod', 127: 'white stork, Ciconia ciconia', 128: 'black stork, Ciconia nigra', 129: 'spoonbill', 130: 'flamingo', 131: 'little blue heron, Egretta caerulea', 132: 'American egret, great white heron, Egretta albus', 133: 'bittern', 134: 'crane', 135: 'limpkin, Aramus pictus', 136: 'European gallinule, Porphyrio porphyrio', 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana', 138: 'bustard', 139: 'ruddy turnstone, Arenaria interpres', 140: 'red-backed sandpiper, dunlin, Erolia alpina', 141: 'redshank, Tringa totanus', 142: 'dowitcher', 143: 'oystercatcher, oyster catcher', 144: 'pelican', 145: 'king penguin, Aptenodytes patagonica', 146: 'albatross, mollymawk', 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', 149: 'dugong, Dugong dugon', 150: 'sea lion', 151: 'Chihuahua', 152: 'Japanese spaniel', 153: 'Maltese dog, Maltese terrier, Maltese', 154: 'Pekinese, Pekingese, Peke', 155: 'Shih-Tzu', 156: 'Blenheim spaniel', 157: 'papillon', 158: 'toy terrier', 159: 'Rhodesian ridgeback', 160: 'Afghan hound, Afghan', 161: 'basset, basset hound', 162: 'beagle', 163: 'bloodhound, sleuthhound', 164: 'bluetick', 165: 'black-and-tan coonhound', 166: 'Walker hound, Walker foxhound', 167: 'English foxhound', 168: 'redbone', 169: 'borzoi, Russian wolfhound', 170: 'Irish wolfhound', 171: 'Italian greyhound', 172: 'whippet', 173: 
'Ibizan hound, Ibizan Podenco', 174: 'Norwegian elkhound, elkhound', 175: 'otterhound, otter hound', 176: 'Saluki, gazelle hound', 177: 'Scottish deerhound, deerhound', 178: 'Weimaraner', 179: 'Staffordshire bullterrier, Staffordshire bull terrier', 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', 181: 'Bedlington terrier', 182: 'Border terrier', 183: 'Kerry blue terrier', 184: 'Irish terrier', 185: 'Norfolk terrier', 186: 'Norwich terrier', 187: 'Yorkshire terrier', 188: 'wire-haired fox terrier', 189: 'Lakeland terrier', 190: 'Sealyham terrier, Sealyham', 191: 'Airedale, Airedale terrier', 192: 'cairn, cairn terrier', 193: 'Australian terrier', 194: 'Dandie Dinmont, Dandie Dinmont terrier', 195: 'Boston bull, Boston terrier', 196: 'miniature schnauzer', 197: 'giant schnauzer', 198: 'standard schnauzer', 199: 'Scotch terrier, Scottish terrier, Scottie', 200: 'Tibetan terrier, chrysanthemum dog', 201: 'silky terrier, Sydney silky', 202: 'soft-coated wheaten terrier', 203: 'West Highland white terrier', 204: 'Lhasa, Lhasa apso', 205: 'flat-coated retriever', 206: 'curly-coated retriever', 207: 'golden retriever', 208: 'Labrador retriever', 209: 'Chesapeake Bay retriever', 210: 'German short-haired pointer', 211: 'vizsla, Hungarian pointer', 212: 'English setter', 213: 'Irish setter, red setter', 214: 'Gordon setter', 215: 'Brittany spaniel', 216: 'clumber, clumber spaniel', 217: 'English springer, English springer spaniel', 218: 'Welsh springer spaniel', 219: 'cocker spaniel, English cocker spaniel, cocker', 220: 'Sussex spaniel', 221: 'Irish water spaniel', 222: 'kuvasz', 223: 'schipperke', 224: 'groenendael', 225: 'malinois', 226: 'briard', 227: 'kelpie', 228: 'komondor', 229: 'Old English sheepdog, bobtail', 230: 'Shetland sheepdog, Shetland sheep dog, Shetland', 231: 'collie', 232: 'Border collie', 233: 'Bouvier des Flandres, Bouviers des Flandres', 234: 'Rottweiler', 235: 'German shepherd, German 
shepherd dog, German police dog, alsatian', 236: 'Doberman, Doberman pinscher', 237: 'miniature pinscher', 238: 'Greater Swiss Mountain dog', 239: 'Bernese mountain dog', 240: 'Appenzeller', 241: 'EntleBucher', 242: 'boxer', 243: 'bull mastiff', 244: 'Tibetan mastiff', 245: 'French bulldog', 246: 'Great Dane', 247: 'Saint Bernard, St Bernard', 248: 'Eskimo dog, husky', 249: 'malamute, malemute, Alaskan malamute', 250: 'Siberian husky', 251: 'dalmatian, coach dog, carriage dog', 252: 'affenpinscher, monkey pinscher, monkey dog', 253: 'basenji', 254: 'pug, pug-dog', 255: 'Leonberg', 256: 'Newfoundland, Newfoundland dog', 257: 'Great Pyrenees', 258: 'Samoyed, Samoyede', 259: 'Pomeranian', 260: 'chow, chow chow', 261: 'keeshond', 262: 'Brabancon griffon', 263: 'Pembroke, Pembroke Welsh corgi', 264: 'Cardigan, Cardigan Welsh corgi', 265: 'toy poodle', 266: 'miniature poodle', 267: 'standard poodle', 268: 'Mexican hairless', 269: 'timber wolf, grey wolf, gray wolf, Canis lupus', 270: 'white wolf, Arctic wolf, Canis lupus tundrarum', 271: 'red wolf, maned wolf, Canis rufus, Canis niger', 272: 'coyote, prairie wolf, brush wolf, Canis latrans', 273: 'dingo, warrigal, warragal, Canis dingo', 274: 'dhole, Cuon alpinus', 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', 276: 'hyena, hyaena', 277: 'red fox, Vulpes vulpes', 278: 'kit fox, Vulpes macrotis', 279: 'Arctic fox, white fox, Alopex lagopus', 280: 'grey fox, gray fox, Urocyon cinereoargenteus', 281: 'tabby, tabby cat', 282: 'tiger cat', 283: 'Persian cat', 284: 'Siamese cat, Siamese', 285: 'Egyptian cat', 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', 287: 'lynx, catamount', 288: 'leopard, Panthera pardus', 289: 'snow leopard, ounce, Panthera uncia', 290: 'jaguar, panther, Panthera onca, Felis onca', 291: 'lion, king of beasts, Panthera leo', 292: 'tiger, Panthera tigris', 293: 'cheetah, chetah, Acinonyx jubatus', 294: 'brown bear, bruin, Ursus arctos', 295: 
'American black bear, black bear, Ursus americanus, Euarctos americanus', 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', 297: 'sloth bear, Melursus ursinus, Ursus ursinus', 298: 'mongoose', 299: 'meerkat, mierkat', 300: 'tiger beetle', 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', 302: 'ground beetle, carabid beetle', 303: 'long-horned beetle, longicorn, longicorn beetle', 304: 'leaf beetle, chrysomelid', 305: 'dung beetle', 306: 'rhinoceros beetle', 307: 'weevil', 308: 'fly', 309: 'bee', 310: 'ant, emmet, pismire', 311: 'grasshopper, hopper', 312: 'cricket', 313: 'walking stick, walkingstick, stick insect', 314: 'cockroach, roach', 315: 'mantis, mantid', 316: 'cicada, cicala', 317: 'leafhopper', 318: 'lacewing, lacewing fly', 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", 320: 'damselfly', 321: 'admiral', 322: 'ringlet, ringlet butterfly', 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', 324: 'cabbage butterfly', 325: 'sulphur butterfly, sulfur butterfly', 326: 'lycaenid, lycaenid butterfly', 327: 'starfish, sea star', 328: 'sea urchin', 329: 'sea cucumber, holothurian', 330: 'wood rabbit, cottontail, cottontail rabbit', 331: 'hare', 332: 'Angora, Angora rabbit', 333: 'hamster', 334: 'porcupine, hedgehog', 335: 'fox squirrel, eastern fox squirrel, Sciurus niger', 336: 'marmot', 337: 'beaver', 338: 'guinea pig, Cavia cobaya', 339: 'sorrel', 340: 'zebra', 341: 'hog, pig, grunter, squealer, Sus scrofa', 342: 'wild boar, boar, Sus scrofa', 343: 'warthog', 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius', 345: 'ox', 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', 347: 'bison', 348: 'ram, tup', 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', 350: 'ibex, Capra ibex', 351: 'hartebeest', 352: 'impala, Aepyceros melampus', 353: 'gazelle', 
354: 'Arabian camel, dromedary, Camelus dromedarius', 355: 'llama', 356: 'weasel', 357: 'mink', 358: 'polecat, fitch, foulmart, foumart, Mustela putorius', 359: 'black-footed ferret, ferret, Mustela nigripes', 360: 'otter', 361: 'skunk, polecat, wood pussy', 362: 'badger', 363: 'armadillo', 364: 'three-toed sloth, ai, Bradypus tridactylus', 365: 'orangutan, orang, orangutang, Pongo pygmaeus', 366: 'gorilla, Gorilla gorilla', 367: 'chimpanzee, chimp, Pan troglodytes', 368: 'gibbon, Hylobates lar', 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus', 370: 'guenon, guenon monkey', 371: 'patas, hussar monkey, Erythrocebus patas', 372: 'baboon', 373: 'macaque', 374: 'langur', 375: 'colobus, colobus monkey', 376: 'proboscis monkey, Nasalis larvatus', 377: 'marmoset', 378: 'capuchin, ringtail, Cebus capucinus', 379: 'howler monkey, howler', 380: 'titi, titi monkey', 381: 'spider monkey, Ateles geoffroyi', 382: 'squirrel monkey, Saimiri sciureus', 383: 'Madagascar cat, ring-tailed lemur, Lemur catta', 384: 'indri, indris, Indri indri, Indri brevicaudatus', 385: 'Indian elephant, Elephas maximus', 386: 'African elephant, Loxodonta africana', 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', 389: 'barracouta, snoek', 390: 'eel', 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', 392: 'rock beauty, Holocanthus tricolor', 393: 'anemone fish', 394: 'sturgeon', 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus', 396: 'lionfish', 397: 'puffer, pufferfish, blowfish, globefish', 398: 'abacus', 399: 'abaya', 400: "academic gown, academic robe, judge's robe", 401: 'accordion, piano accordion, squeeze box', 402: 'acoustic guitar', 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier', 404: 'airliner', 405: 'airship, dirigible', 406: 'altar', 407: 'ambulance', 408: 'amphibian, amphibious vehicle', 409: 'analog clock', 410: 
'apiary, bee house', 411: 'apron', 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', 413: 'assault rifle, assault gun', 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack', 415: 'bakery, bakeshop, bakehouse', 416: 'balance beam, beam', 417: 'balloon', 418: 'ballpoint, ballpoint pen, ballpen, Biro', 419: 'Band Aid', 420: 'banjo', 421: 'bannister, banister, balustrade, balusters, handrail', 422: 'barbell', 423: 'barber chair', 424: 'barbershop', 425: 'barn', 426: 'barometer', 427: 'barrel, cask', 428: 'barrow, garden cart, lawn cart, wheelbarrow', 429: 'baseball', 430: 'basketball', 431: 'bassinet', 432: 'bassoon', 433: 'bathing cap, swimming cap', 434: 'bath towel', 435: 'bathtub, bathing tub, bath, tub', 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', 437: 'beacon, lighthouse, beacon light, pharos', 438: 'beaker', 439: 'bearskin, busby, shako', 440: 'beer bottle', 441: 'beer glass', 442: 'bell cote, bell cot', 443: 'bib', 444: 'bicycle-built-for-two, tandem bicycle, tandem', 445: 'bikini, two-piece', 446: 'binder, ring-binder', 447: 'binoculars, field glasses, opera glasses', 448: 'birdhouse', 449: 'boathouse', 450: 'bobsled, bobsleigh, bob', 451: 'bolo tie, bolo, bola tie, bola', 452: 'bonnet, poke bonnet', 453: 'bookcase', 454: 'bookshop, bookstore, bookstall', 455: 'bottlecap', 456: 'bow', 457: 'bow tie, bow-tie, bowtie', 458: 'brass, memorial tablet, plaque', 459: 'brassiere, bra, bandeau', 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', 461: 'breastplate, aegis, egis', 462: 'broom', 463: 'bucket, pail', 464: 'buckle', 465: 'bulletproof vest', 466: 'bullet train, bullet', 467: 'butcher shop, meat market', 468: 'cab, hack, taxi, taxicab', 469: 'caldron, cauldron', 470: 'candle, taper, wax light', 471: 'cannon', 472: 'canoe', 473: 'can opener, tin opener', 474: 'cardigan', 475: 'car mirror', 476: 'carousel, carrousel, 
merry-go-round, roundabout, whirligig', 477: "carpenter's kit, tool kit", 478: 'carton', 479: 'car wheel', 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', 481: 'cassette', 482: 'cassette player', 483: 'castle', 484: 'catamaran', 485: 'CD player', 486: 'cello, violoncello', 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone', 488: 'chain', 489: 'chainlink fence', 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', 491: 'chain saw, chainsaw', 492: 'chest', 493: 'chiffonier, commode', 494: 'chime, bell, gong', 495: 'china cabinet, china closet', 496: 'Christmas stocking', 497: 'church, church building', 498: 'cinema, movie theater, movie theatre, movie house, picture palace', 499: 'cleaver, meat cleaver, chopper', 500: 'cliff dwelling', 501: 'cloak', 502: 'clog, geta, patten, sabot', 503: 'cocktail shaker', 504: 'coffee mug', 505: 'coffeepot', 506: 'coil, spiral, volute, whorl, helix', 507: 'combination lock', 508: 'computer keyboard, keypad', 509: 'confectionery, confectionary, candy store', 510: 'container ship, containership, container vessel', 511: 'convertible', 512: 'corkscrew, bottle screw', 513: 'cornet, horn, trumpet, trump', 514: 'cowboy boot', 515: 'cowboy hat, ten-gallon hat', 516: 'cradle', 517: 'crane', 518: 'crash helmet', 519: 'crate', 520: 'crib, cot', 521: 'Crock Pot', 522: 'croquet ball', 523: 'crutch', 524: 'cuirass', 525: 'dam, dike, dyke', 526: 'desk', 527: 'desktop computer', 528: 'dial telephone, dial phone', 529: 'diaper, nappy, napkin', 530: 'digital clock', 531: 'digital watch', 532: 'dining table, board', 533: 'dishrag, dishcloth', 534: 'dishwasher, dish washer, dishwashing machine', 535: 'disk brake, disc brake', 536: 'dock, dockage, docking facility', 537: 'dogsled, dog sled, dog sleigh', 538: 'dome', 539: 'doormat, welcome mat', 540: 'drilling platform, offshore rig', 541: 'drum, membranophone, 
tympan', 542: 'drumstick', 543: 'dumbbell', 544: 'Dutch oven', 545: 'electric fan, blower', 546: 'electric guitar', 547: 'electric locomotive', 548: 'entertainment center', 549: 'envelope', 550: 'espresso maker', 551: 'face powder', 552: 'feather boa, boa', 553: 'file, file cabinet, filing cabinet', 554: 'fireboat', 555: 'fire engine, fire truck', 556: 'fire screen, fireguard', 557: 'flagpole, flagstaff', 558: 'flute, transverse flute', 559: 'folding chair', 560: 'football helmet', 561: 'forklift', 562: 'fountain', 563: 'fountain pen', 564: 'four-poster', 565: 'freight car', 566: 'French horn, horn', 567: 'frying pan, frypan, skillet', 568: 'fur coat', 569: 'garbage truck, dustcart', 570: 'gasmask, respirator, gas helmet', 571: 'gas pump, gasoline pump, petrol pump, island dispenser', 572: 'goblet', 573: 'go-kart', 574: 'golf ball', 575: 'golfcart, golf cart', 576: 'gondola', 577: 'gong, tam-tam', 578: 'gown', 579: 'grand piano, grand', 580: 'greenhouse, nursery, glasshouse', 581: 'grille, radiator grille', 582: 'grocery store, grocery, food market, market', 583: 'guillotine', 584: 'hair slide', 585: 'hair spray', 586: 'half track', 587: 'hammer', 588: 'hamper', 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier', 590: 'hand-held computer, hand-held microcomputer', 591: 'handkerchief, hankie, hanky, hankey', 592: 'hard disc, hard disk, fixed disk', 593: 'harmonica, mouth organ, harp, mouth harp', 594: 'harp', 595: 'harvester, reaper', 596: 'hatchet', 597: 'holster', 598: 'home theater, home theatre', 599: 'honeycomb', 600: 'hook, claw', 601: 'hoopskirt, crinoline', 602: 'horizontal bar, high bar', 603: 'horse cart, horse-cart', 604: 'hourglass', 605: 'iPod', 606: 'iron, smoothing iron', 607: "jack-o'-lantern", 608: 'jean, blue jean, denim', 609: 'jeep, landrover', 610: 'jersey, T-shirt, tee shirt', 611: 'jigsaw puzzle', 612: 'jinrikisha, ricksha, rickshaw', 613: 'joystick', 614: 'kimono', 615: 'knee pad', 616: 'knot', 617: 'lab coat, laboratory 
coat', 618: 'ladle', 619: 'lampshade, lamp shade', 620: 'laptop, laptop computer', 621: 'lawn mower, mower', 622: 'lens cap, lens cover', 623: 'letter opener, paper knife, paperknife', 624: 'library', 625: 'lifeboat', 626: 'lighter, light, igniter, ignitor', 627: 'limousine, limo', 628: 'liner, ocean liner', 629: 'lipstick, lip rouge', 630: 'Loafer', 631: 'lotion', 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', 633: "loupe, jeweler's loupe", 634: 'lumbermill, sawmill', 635: 'magnetic compass', 636: 'mailbag, postbag', 637: 'mailbox, letter box', 638: 'maillot', 639: 'maillot, tank suit', 640: 'manhole cover', 641: 'maraca', 642: 'marimba, xylophone', 643: 'mask', 644: 'matchstick', 645: 'maypole', 646: 'maze, labyrinth', 647: 'measuring cup', 648: 'medicine chest, medicine cabinet', 649: 'megalith, megalithic structure', 650: 'microphone, mike', 651: 'microwave, microwave oven', 652: 'military uniform', 653: 'milk can', 654: 'minibus', 655: 'miniskirt, mini', 656: 'minivan', 657: 'missile', 658: 'mitten', 659: 'mixing bowl', 660: 'mobile home, manufactured home', 661: 'Model T', 662: 'modem', 663: 'monastery', 664: 'monitor', 665: 'moped', 666: 'mortar', 667: 'mortarboard', 668: 'mosque', 669: 'mosquito net', 670: 'motor scooter, scooter', 671: 'mountain bike, all-terrain bike, off-roader', 672: 'mountain tent', 673: 'mouse, computer mouse', 674: 'mousetrap', 675: 'moving van', 676: 'muzzle', 677: 'nail', 678: 'neck brace', 679: 'necklace', 680: 'nipple', 681: 'notebook, notebook computer', 682: 'obelisk', 683: 'oboe, hautboy, hautbois', 684: 'ocarina, sweet potato', 685: 'odometer, hodometer, mileometer, milometer', 686: 'oil filter', 687: 'organ, pipe organ', 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO', 689: 'overskirt', 690: 'oxcart', 691: 'oxygen mask', 692: 'packet', 693: 'paddle, boat paddle', 694: 'paddlewheel, paddle wheel', 695: 'padlock', 696: 'paintbrush', 697: "pajama, pyjama, pj's, jammies", 698: 'palace', 
699: 'panpipe, pandean pipe, syrinx', 700: 'paper towel', 701: 'parachute, chute', 702: 'parallel bars, bars', 703: 'park bench', 704: 'parking meter', 705: 'passenger car, coach, carriage', 706: 'patio, terrace', 707: 'pay-phone, pay-station', 708: 'pedestal, plinth, footstall', 709: 'pencil box, pencil case', 710: 'pencil sharpener', 711: 'perfume, essence', 712: 'Petri dish', 713: 'photocopier', 714: 'pick, plectrum, plectron', 715: 'pickelhaube', 716: 'picket fence, paling', 717: 'pickup, pickup truck', 718: 'pier', 719: 'piggy bank, penny bank', 720: 'pill bottle', 721: 'pillow', 722: 'ping-pong ball', 723: 'pinwheel', 724: 'pirate, pirate ship', 725: 'pitcher, ewer', 726: "plane, carpenter's plane, woodworking plane", 727: 'planetarium', 728: 'plastic bag', 729: 'plate rack', 730: 'plow, plough', 731: "plunger, plumber's helper", 732: 'Polaroid camera, Polaroid Land camera', 733: 'pole', 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', 735: 'poncho', 736: 'pool table, billiard table, snooker table', 737: 'pop bottle, soda bottle', 738: 'pot, flowerpot', 739: "potter's wheel", 740: 'power drill', 741: 'prayer rug, prayer mat', 742: 'printer', 743: 'prison, prison house', 744: 'projectile, missile', 745: 'projector', 746: 'puck, hockey puck', 747: 'punching bag, punch bag, punching ball, punchball', 748: 'purse', 749: 'quill, quill pen', 750: 'quilt, comforter, comfort, puff', 751: 'racer, race car, racing car', 752: 'racket, racquet', 753: 'radiator', 754: 'radio, wireless', 755: 'radio telescope, radio reflector', 756: 'rain barrel', 757: 'recreational vehicle, RV, R.V.', 758: 'reel', 759: 'reflex camera', 760: 'refrigerator, icebox', 761: 'remote control, remote', 762: 'restaurant, eating house, eating place, eatery', 763: 'revolver, six-gun, six-shooter', 764: 'rifle', 765: 'rocking chair, rocker', 766: 'rotisserie', 767: 'rubber eraser, rubber, pencil eraser', 768: 'rugby ball', 769: 'rule, ruler', 770: 'running shoe', 771: 
'safe', 772: 'safety pin', 773: 'saltshaker, salt shaker', 774: 'sandal', 775: 'sarong', 776: 'sax, saxophone', 777: 'scabbard', 778: 'scale, weighing machine', 779: 'school bus', 780: 'schooner', 781: 'scoreboard', 782: 'screen, CRT screen', 783: 'screw', 784: 'screwdriver', 785: 'seat belt, seatbelt', 786: 'sewing machine', 787: 'shield, buckler', 788: 'shoe shop, shoe-shop, shoe store', 789: 'shoji', 790: 'shopping basket', 791: 'shopping cart', 792: 'shovel', 793: 'shower cap', 794: 'shower curtain', 795: 'ski', 796: 'ski mask', 797: 'sleeping bag', 798: 'slide rule, slipstick', 799: 'sliding door', 800: 'slot, one-armed bandit', 801: 'snorkel', 802: 'snowmobile', 803: 'snowplow, snowplough', 804: 'soap dispenser', 805: 'soccer ball', 806: 'sock', 807: 'solar dish, solar collector, solar furnace', 808: 'sombrero', 809: 'soup bowl', 810: 'space bar', 811: 'space heater', 812: 'space shuttle', 813: 'spatula', 814: 'speedboat', 815: "spider web, spider's web", 816: 'spindle', 817: 'sports car, sport car', 818: 'spotlight, spot', 819: 'stage', 820: 'steam locomotive', 821: 'steel arch bridge', 822: 'steel drum', 823: 'stethoscope', 824: 'stole', 825: 'stone wall', 826: 'stopwatch, stop watch', 827: 'stove', 828: 'strainer', 829: 'streetcar, tram, tramcar, trolley, trolley car', 830: 'stretcher', 831: 'studio couch, day bed', 832: 'stupa, tope', 833: 'submarine, pigboat, sub, U-boat', 834: 'suit, suit of clothes', 835: 'sundial', 836: 'sunglass', 837: 'sunglasses, dark glasses, shades', 838: 'sunscreen, sunblock, sun blocker', 839: 'suspension bridge', 840: 'swab, swob, mop', 841: 'sweatshirt', 842: 'swimming trunks, bathing trunks', 843: 'swing', 844: 'switch, electric switch, electrical switch', 845: 'syringe', 846: 'table lamp', 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle', 848: 'tape player', 849: 'teapot', 850: 'teddy, teddy bear', 851: 'television, television system', 852: 'tennis ball', 853: 'thatch, thatched roof', 854: 'theater 
curtain, theatre curtain', 855: 'thimble', 856: 'thresher, thrasher, threshing machine', 857: 'throne', 858: 'tile roof', 859: 'toaster', 860: 'tobacco shop, tobacconist shop, tobacconist', 861: 'toilet seat', 862: 'torch', 863: 'totem pole', 864: 'tow truck, tow car, wrecker', 865: 'toyshop', 866: 'tractor', 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', 868: 'tray', 869: 'trench coat', 870: 'tricycle, trike, velocipede', 871: 'trimaran', 872: 'tripod', 873: 'triumphal arch', 874: 'trolleybus, trolley coach, trackless trolley', 875: 'trombone', 876: 'tub, vat', 877: 'turnstile', 878: 'typewriter keyboard', 879: 'umbrella', 880: 'unicycle, monocycle', 881: 'upright, upright piano', 882: 'vacuum, vacuum cleaner', 883: 'vase', 884: 'vault', 885: 'velvet', 886: 'vending machine', 887: 'vestment', 888: 'viaduct', 889: 'violin, fiddle', 890: 'volleyball', 891: 'waffle iron', 892: 'wall clock', 893: 'wallet, billfold, notecase, pocketbook', 894: 'wardrobe, closet, press', 895: 'warplane, military plane', 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', 897: 'washer, automatic washer, washing machine', 898: 'water bottle', 899: 'water jug', 900: 'water tower', 901: 'whiskey jug', 902: 'whistle', 903: 'wig', 904: 'window screen', 905: 'window shade', 906: 'Windsor tie', 907: 'wine bottle', 908: 'wing', 909: 'wok', 910: 'wooden spoon', 911: 'wool, woolen, woollen', 912: 'worm fence, snake fence, snake-rail fence, Virginia fence', 913: 'wreck', 914: 'yawl', 915: 'yurt', 916: 'web site, website, internet site, site', 917: 'comic book', 918: 'crossword puzzle, crossword', 919: 'street sign', 920: 'traffic light, traffic signal, stoplight', 921: 'book jacket, dust cover, dust jacket, dust wrapper', 922: 'menu', 923: 'plate', 924: 'guacamole', 925: 'consomme', 926: 'hot pot, hotpot', 927: 'trifle', 928: 'ice cream, icecream', 929: 'ice lolly, lolly, lollipop, popsicle', 930: 'French loaf', 931: 'bagel, beigel', 932: 'pretzel', 
933: 'cheeseburger', 934: 'hotdog, hot dog, red hot', 935: 'mashed potato', 936: 'head cabbage', 937: 'broccoli', 938: 'cauliflower', 939: 'zucchini, courgette', 940: 'spaghetti squash', 941: 'acorn squash', 942: 'butternut squash', 943: 'cucumber, cuke', 944: 'artichoke, globe artichoke', 945: 'bell pepper', 946: 'cardoon', 947: 'mushroom', 948: 'Granny Smith', 949: 'strawberry', 950: 'orange', 951: 'lemon', 952: 'fig', 953: 'pineapple, ananas', 954: 'banana', 955: 'jackfruit, jak, jack', 956: 'custard apple', 957: 'pomegranate', 958: 'hay', 959: 'carbonara', 960: 'chocolate sauce, chocolate syrup', 961: 'dough', 962: 'meat loaf, meatloaf', 963: 'pizza, pizza pie', 964: 'potpie', 965: 'burrito', 966: 'red wine', 967: 'espresso', 968: 'cup', 969: 'eggnog', 970: 'alp', 971: 'bubble', 972: 'cliff, drop, drop-off', 973: 'coral reef', 974: 'geyser', 975: 'lakeside, lakeshore', 976: 'promontory, headland, head, foreland', 977: 'sandbar, sand bar', 978: 'seashore, coast, seacoast, sea-coast', 979: 'valley, vale', 980: 'volcano', 981: 'ballplayer, baseball player', 982: 'groom, bridegroom', 983: 'scuba diver', 984: 'rapeseed', 985: 'daisy', 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", 987: 'corn', 988: 'acorn', 989: 'hip, rose hip, rosehip', 990: 'buckeye, horse chestnut, conker', 991: 'coral fungus', 992: 'agaric', 993: 'gyromitra', 994: 'stinkhorn, carrion fungus', 995: 'earthstar', 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', 997: 'bolete', 998: 'ear, spike, capitulum', 999: 'toilet tissue, toilet paper, bathroom tissue'}
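
A lookup table like the one above is typically used to turn a predicted class index into a readable label. The following is a minimal sketch that works on a three-entry excerpt of the dictionary; the names `labels_excerpt` and `index_to_label` are hypothetical and are not taken from the repository.

```python
# Excerpt of the index-to-label dictionary shown above (three entries only).
labels_excerpt = {
    954: 'banana',
    963: 'pizza, pizza pie',
    999: 'toilet tissue, toilet paper, bathroom tissue',
}


def index_to_label(index, labels, primary_only=True):
    """Return the label for a class index.

    Each entry lists comma-separated synonyms; with primary_only=True
    only the first synonym is returned.
    """
    name = labels[index]
    if primary_only:
        return name.split(',')[0].strip()
    return name


print(index_to_label(963, labels_excerpt))
print(index_to_label(963, labels_excerpt, primary_only=False))
```

Keeping only the first synonym gives shorter output when printing predictions; the full string is still available when needed.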

Results

1.LeNet5

basenet: lenet5 (image size: 32 * 32 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30
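
To see what the scheduler settings above mean in practice, the learning rate per epoch can be worked out by hand. The sketch below uses plain Python rather than PyTorch and assumes the usual multi-step rule (multiply the base rate by gamma once for every milestone already reached, with the scheduler stepped once per epoch); `multistep_lr` is a hypothetical helper, not code from the repository.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate used during `epoch` (0-based) under multi-step decay:
    the base rate is multiplied by gamma once per milestone already reached."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


# Values taken from the configuration above.
base_lr, gamma, milestones, epochs = 0.01, 0.1, [15, 20, 30], 30
dataset_len, batch_size = 50000, 32

# 50000 // 32 = 1562 full batches per epoch, which matches iter_size.
# (Assumption: the incomplete last batch is not counted.)
iter_size = dataset_len // batch_size

schedule = [multistep_lr(base_lr, milestones, gamma, e) for e in range(epochs)]
for e in (0, 14, 15, 19, 20, 29):
    print(f"epoch {e:2d}: lr = {schedule[e]:.0e}")
```

Under these assumptions the rate drops by a factor of ten at epochs 15 and 20, and the third milestone (30) is only reached once the 30 training epochs are over.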

Total

epoch    time        average top-1 acc (%)    average top-5 acc (%)
30       0h11m44s    62.21                    95.97
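
For readers unfamiliar with the two metrics reported here: top-1 accuracy counts a prediction as correct only if the highest-scoring class is the true one, while top-5 accuracy counts it as correct if the true class is among the five highest scores. A minimal pure-Python sketch follows; `topk_accuracy` is a hypothetical helper and the scores are made-up toy values, not results from the article. In PyTorch the same ranking is usually obtained with `torch.topk`.

```python
def topk_accuracy(scores, targets, ks=(1, 5)):
    """Percentage of samples whose true class is among the k highest scores."""
    hits = {k: 0 for k in ks}
    for row, target in zip(scores, targets):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * hits[k] / len(targets) for k in ks}


# Toy example: 3 samples, 6 classes.
scores = [
    [0.10, 0.50, 0.20, 0.10, 0.05, 0.05],  # highest score: class 1
    [0.30, 0.10, 0.25, 0.20, 0.10, 0.05],  # highest score: class 0
    [0.05, 0.10, 0.15, 0.20, 0.20, 0.30],  # highest score: class 5
]
targets = [1, 2, 0]

result = topk_accuracy(scores, targets)
print(result)
```

In this toy case one of three samples is a top-1 hit and two of three are top-5 hits, which shows why top-5 accuracy is always at least as high as top-1.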

2.AlexNet

basenet: alexnet (image size: 224 * 224 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30

Total

epoch    time        average top-1 acc (%)    average top-5 acc (%)
30       0h22m44s    86.27                    99.0

3.VGG

basenet: vgg16 (image size: 224 * 224 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30

Total

epoch    time        average top-1 acc (%)    average top-5 acc (%)
30       1h23m43s    76.56                    96.44

4.ResNet

basenet: resnet18
dataset: ImageNet
image size: 224 * 224 * 3 (customizable)
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.001
epoch: 30

epoch    time        top-1 acc (%)    top-5 acc (%)
5        3h49m35s    50.21            75.59

Next chapter

CV + Deep Learning - Network Architecture Pytorch Reproduction Series - Classification (2: ResNeXt, GoogLeNet, MobileNet)
https://blog.csdn.net/XiaoyYidiaodiao/article/details/125692368?csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22125692368%22%2C%22source%22%3A%22XiaoyYidiaodiao%22%7D&ctrtid=yBcgN

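[1] LeCun Y, Bottou L, Bengio Y, et al. Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11): 2278-2324.

[2] Krizhevsky A, Sutskever I, Hinton G E. ImageNet classification with deep convolutional neural networks[J]. Advances in Neural Information Processing Systems, 2012, 25.

[3] Simonyan K, Zisserman A. Very deep convolutional networks for large-scale image recognition[J]. arXiv preprint arXiv:1409.1556, 2014.

[4] He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016: 770-778.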
