dotproduct_attentionMLP.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. import numpy as np
  2. import tensorflow as tf
  3. from tensorflow import keras
  4. from tensorflow.keras import layers
  5. import tensorflow_addons as tfa
  6. import matplotlib.pyplot as plt
  7. class PatchExtract(layers.Layer): #提取patch
  8. def __init__(self, patch_size, **kwargs):
  9. super(PatchExtract, self).__init__(**kwargs)
  10. self.patch_size = patch_size
  11. def call(self, images):
  12. batch_size = tf.shape(images)[0]
  13. patches = tf.image.extract_patches(
  14. images=images,
  15. sizes=(1, self.patch_size, self.patch_size, 1),
  16. strides=(1, self.patch_size, self.patch_size, 1),
  17. rates=(1, 1, 1, 1),
  18. padding="VALID",
  19. )
  20. patch_dim = patches.shape[-1]
  21. patch_num = patches.shape[1]
  22. return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))
  23. class PatchEmbedding(layers.Layer): #patch转化为地位矩阵embedding
  24. def __init__(self, num_patch, embed_dim, **kwargs):
  25. super(PatchEmbedding, self).__init__(**kwargs)
  26. self.num_patch = num_patch
  27. self.proj = layers.Dense(embed_dim)
  28. self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
  29. def call(self, patch):
  30. pos = tf.range(start=0, limit=self.num_patch, delta=1)
  31. return self.proj(patch) + self.pos_embed(pos)
  32. def dotproduct_attention(
  33. x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0
  34. ): #点积计算attention值,输入为patch批次矩阵((2*2),数量,通道数),embedding_dim,
  35. #超参数num_heads(保证embedding_num和num_head是倍数关系)
  36. #和常规参数embedding系数、attention单元的dropout率(去掉一些运算过程中产生的参数,设为0是不去除)和映射层的dropout
  37. _, num_patch, channel = x.shape #获取patch数量和通道个数
  38. assert dim % num_heads == 0 #确保embedding_num和num_head是倍数关系
  39. num_heads = num_heads * dim_coefficient #
  40. x = layers.Dense(dim * dim_coefficient)(x) #定义一个网络层,执行的操作是func(input*kernel)+bias,这里的*是计算点积,
  41. #计算patches与embedding的点积并赋值给patches
  42. x = tf.reshape(
  43. x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)
  44. ) #将patches重新还原为原来的维度
  45. x = tf.transpose(x, perm=[0, 2, 1, 3]) #求pathes的转置(高和列),并将转置的矩阵赋值给patches
  46. attn = layers.Dense(dim // dim_coefficient)(x) #网络层,将转置后的patches计算点积产生attention向量
  47. attn = layers.Softmax(axis=2)(attn) #softmax函数,计算attention向量
  48. attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True)) #计算attention值,tf.reduce_sum是按维度求和,
  49. #通过attention向量计算attention值
  50. attn = layers.Dropout(attention_dropout)(attn) #去除中间过程产生的参数
  51. x = layers.Dense(dim * dim_coefficient // num_heads)(attn) #计算patche和attention值的点积
  52. x = tf.transpose(x, perm=[0, 2, 1, 3]) #patches高和列转置
  53. x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient]) #复原patches为原来的维度
  54. x = layers.Dense(dim)(x) #计算dim与patches的点积
  55. x = layers.Dropout(projection_dropout)(x) #去除中间过程产生的参数
  56. return x
  57. def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2): #喂给MLP
  58. x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)
  59. x = layers.Dropout(drop_rate)(x)
  60. x = layers.Dense(embedding_dim)(x)
  61. x = layers.Dropout(drop_rate)(x)
  62. return x
  63. def transformer_encoder(
  64. x,
  65. embedding_dim,
  66. mlp_dim,
  67. num_heads,
  68. dim_coefficient,
  69. attention_dropout,
  70. projection_dropout,
  71. attention_type="dotproduct_attention",
  72. ): #encoder步骤
  73. residual_1 = x
  74. x = layers.LayerNormalization(epsilon=1e-5)(x)
  75. if attention_type == "dotproduct_attention":
  76. x = dotproduct_attention(
  77. x,
  78. embedding_dim,
  79. num_heads,
  80. dim_coefficient,
  81. attention_dropout,
  82. projection_dropout,
  83. )
  84. elif attention_type == "self_attention":
  85. x = layers.MultiHeadAttention(
  86. num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout
  87. )(x, x)
  88. x = layers.add([x, residual_1])
  89. residual_2 = x
  90. x = layers.LayerNormalization(epsilon=1e-5)(x)
  91. x = mlp(x, embedding_dim, mlp_dim)
  92. x = layers.add([x, residual_2])
  93. return x
  94. def get_model(attention_type="dotproduct_attention"):
  95. inputs = layers.Input(shape=input_shape)
  96. x = data_augmentation(inputs)
  97. x = PatchExtract(patch_size)(x)
  98. x = PatchEmbedding(num_patches, embedding_dim)(x)
  99. for _ in range(num_transformer_blocks):
  100. x = transformer_encoder(
  101. x,
  102. embedding_dim,
  103. mlp_dim,
  104. num_heads,
  105. dim_coefficient,
  106. attention_dropout,
  107. projection_dropout,
  108. attention_type,
  109. )
  110. x = layers.GlobalAvgPool1D()(x)
  111. outputs = layers.Dense(num_classes, activation="softmax")(x)
  112. model = keras.Model(inputs=inputs, outputs=outputs)
  113. return model
  114. #加载数据
  115. num_classes = 100
  116. input_shape = (32, 32, 3)
  117. (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
  118. y_train = keras.utils.to_categorical(y_train, num_classes)
  119. y_test = keras.utils.to_categorical(y_test, num_classes)
  120. print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
  121. print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
  122. #设置超参数
  123. weight_decay = 0.0001
  124. learning_rate = 0.001
  125. label_smoothing = 0.1
  126. validation_split = 0.2
  127. batch_size = 128
  128. num_epochs = 50
  129. patch_size = 2 # 从原图提取patch的窗口大小2*2
  130. num_patches = (input_shape[0] // patch_size) ** 2 # patch数量
  131. embedding_dim = 64 # 隐藏单元数量
  132. mlp_dim = 64
  133. dim_coefficient = 4
  134. num_heads = 4
  135. attention_dropout = 0.2
  136. projection_dropout = 0.2
  137. num_transformer_blocks = 8 #transformer层的重复次数
  138. print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
  139. print(f"Patches per image: {num_patches}")
  140. #数据增强
  141. data_augmentation = keras.Sequential(
  142. [
  143. layers.Normalization(),
  144. layers.RandomFlip("horizontal"),
  145. layers.RandomRotation(factor=0.1),
  146. layers.RandomContrast(factor=0.1),
  147. layers.RandomZoom(height_factor=0.2, width_factor=0.2),
  148. ],
  149. name="data_augmentation",
  150. )
  151. #计算训练集的平均值和方差,便于正则化训练集
  152. data_augmentation.layers[0].adapt(x_train)
  153. #开始调用模型
  154. model = get_model(attention_type="dotproduct_attention")
  155. model.compile(
  156. loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
  157. optimizer=tfa.optimizers.AdamW(
  158. learning_rate=learning_rate, weight_decay=weight_decay
  159. ),
  160. metrics=[
  161. keras.metrics.CategoricalAccuracy(name="accuracy"),
  162. keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
  163. ],
  164. )
  165. history = model.fit(
  166. x_train,
  167. y_train,
  168. batch_size=batch_size,
  169. epochs=num_epochs,
  170. validation_split=validation_split,
  171. )
  172. #画混淆矩阵
  173. from sklearn.metrics import confusion_matrix
  174. import itertools
  175. plt.rcParams['figure.figsize'] = [12,12]
  176. def plot_confusion_matrix(cm, classes,
  177. normalize=False,
  178. title='Confusion matrix',
  179. cmap=plt.cm.Blues):
  180. if normalize:
  181. cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  182. print("Normalized confusion matrix")
  183. else:
  184. print('Confusion matrix, without normalization')
  185. print(cm)
  186. plt.imshow(cm, interpolation='nearest', cmap=cmap)
  187. plt.title(title)
  188. plt.colorbar()
  189. tick_marks = np.arange(len(classes))
  190. plt.xticks(tick_marks[0::2], classes[0::2], rotation=0)
  191. plt.yticks(tick_marks[0::2], classes[0::2])
  192. '''
  193. fmt = '.2f' if normalize else 'd'
  194. thresh = cm.max() / 2.
  195. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  196. plt.text(j, i, format(cm[i, j], fmt),
  197. horizontalalignment="center",
  198. color="white" if cm[i, j] > thresh else "black")
  199. '''
  200. plt.tight_layout()
  201. plt.ylabel('True label')
  202. plt.xlabel('Predicted label')
  203. plt.savefig('./picture/confusion_matrix.jpeg',dpi=1200,bbox_inches='tight')
  204. plt.show()
  205. p_test = model.predict(x_test).argmax(axis=1)
  206. cm = confusion_matrix(y_test.argmax(axis=1), p_test)
  207. plot_confusion_matrix(cm, list(range(100)))
  208. #画损失函数
  209. plt.plot(history.history["loss"], label="train_loss")
  210. plt.plot(history.history["val_loss"], label="val_loss")
  211. plt.xlabel("Epochs")
  212. plt.ylabel("Loss")
  213. plt.title("Train and Validation Losses Over Epochs", fontsize=14)
  214. plt.legend()
  215. plt.grid()
  216. plt.savefig('./picture/loss_function.jpeg',dpi=800,bbox_inches='tight')
  217. plt.show()